{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we'll be implementing the Transformer-XL architecture from scratch.\n",
    "To run this notebook in full, make sure to use the `download_data.sh` script to download the Penn Treebank data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu\") if not torch.cuda.is_available() else torch.device(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Review of the Transformer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start off with a quick overview of the Transformer architecture which will be the basis of the Transformer XL.\n",
    "\n",
    "Overall, the Transformer is built from a stack of identical blocks. Each block contains a MultiHeadAttention layer followed by a feedforward layer, and each of these sublayers is wrapped in a residual connection followed by layer normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://camo.githubusercontent.com/88e8f36ce61dedfd2491885b8df2f68c4d1f92f5/687474703a2f2f696d6775722e636f6d2f316b72463252362e706e67)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The MultiHeadAttention layer is composed of multiple attention heads. Each attention head applies a linear transformation to its inputs and computes attention over its input values using keys and queries."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/multi_head_attention.png?zoom=2&resize=224%2C293)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Attention on its own has no notion of word order, so the Transformer adds embeddings representing the position of each input token to the word embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more details, please refer to [this tutorial](https://github.com/keitakurita/Practical_NLP_in_PyTorch/blob/master/deep_dives/transformer_from_scratch.ipynb) I've written on just the Transformer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be building the Transformer XL by first implementing a single attention head, scaling it to compose the MultiHeadAttention layer for the Transformer XL, then building the DecoderBlock and stacking them to create the full Transformer XL."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing the Transformer XL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Single Attention Head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start off by implementing a single attention head in a MultiHeadAttention layer. To make things concrete, let's consider the first layer and assume we receive an input of word embeddings of shape `(seq=7, batch_size=3, embedding_dim=32)`. Note that the Transformer XL does not add positional embeddings to the input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq, batch_size, embedding_dim = 7, 3, 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_embs = torch.rand(seq, batch_size, embedding_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the Transformer XL, each layer also receives the hidden states that the model cached for the previous sequence (the \"memory\"). Since we are looking at the first layer, these cached states are simply the word embeddings of the previous sequence, which we feed as an additional input to our model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://4.bp.blogspot.com/-Do42uKiMvKg/XFCns7oXi5I/AAAAAAAADuc/ZS-p1XHZUNo3K9wv6nRG5AmdEK7mJsrugCLcBGAs/s1600/xl-eval.gif)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this example, let's assume the previous sequence had length `prev_seq=6`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "prev_seq = 6\n",
    "memory = torch.rand(prev_seq, batch_size, embedding_dim) # hidden states from the previous sequence"
   ]
  },
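  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of the segment-level recurrence, to show how the memory is carried from one sequence to the next. It is not the model we build below. The identity \"layer\" and the names `seg_len` and `mem_len` are hypothetical stand-ins. We also assume, as in the paper, that gradients are not propagated into the memory, and we express this with `.detach()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "batch_size, embedding_dim = 3, 32\n",
    "tokens = torch.rand(20, batch_size, embedding_dim)  # a long input that we split into sequences\n",
    "seg_len, mem_len = 5, 5  # hypothetical sequence (segment) and memory lengths\n",
    "\n",
    "memory = torch.zeros(0, batch_size, embedding_dim)  # empty memory before the first sequence\n",
    "context_lens = []\n",
    "for start in range(0, tokens.size(0), seg_len):\n",
    "    segment = tokens[start:start + seg_len]\n",
    "    # keys/values for this sequence would be computed over [memory; segment]\n",
    "    context = torch.cat([memory, segment], dim=0)\n",
    "    context_lens.append(context.size(0))\n",
    "    hidden = segment  # placeholder for the layer output (identity, for illustration only)\n",
    "    # keep the most recent mem_len states and stop gradients from flowing into them\n",
    "    memory = torch.cat([memory, hidden], dim=0)[-mem_len:].detach()\n",
    "\n",
    "print(context_lens)  # [5, 10, 10, 10]"
   ]
  },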
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each attention head takes keys, queries, and values as input. The processing goes like this:\n",
    "\n",
    "1. Apply a separate linear transformation to each of the keys, queries, and values.\n",
    "2. Compute an attention score for each query-key pair.\n",
    "3. For each query, compute an attention-weighted sum of the values.\n",
    "4. Apply a residual connection and layer normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start off with the linear transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "inner_dim = 17 # this will be the internal dimension\n",
    "linear_k = nn.Linear(embedding_dim, inner_dim)\n",
    "linear_v = nn.Linear(embedding_dim, inner_dim)\n",
    "linear_q = nn.Linear(embedding_dim, inner_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The memory is concatenated with the current input along the sequence dimension and used for the keys and values. Be careful: it is not concatenated with the queries. Each query corresponds to a word we want to predict, so the number of queries must stay equal to the length of the current sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_embs_w_memory = torch.cat([memory, word_embs], dim=0)\n",
    "k_tfmd = linear_k(word_embs_w_memory)\n",
    "v_tfmd = linear_v(word_embs_w_memory)\n",
    "q_tfmd = linear_q(word_embs) # No memory for the queries"
   ]
  },
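  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the sketch below repeats the computation above with fresh random tensors of the same (assumed) sizes and prints the shapes. The keys span `prev_seq + seq = 13` positions and the queries only `seq = 7`. The attention scores therefore have one row per query and one column per key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "seq, prev_seq, batch_size, embedding_dim, inner_dim = 7, 6, 3, 32, 17\n",
    "word_embs = torch.rand(seq, batch_size, embedding_dim)\n",
    "memory = torch.rand(prev_seq, batch_size, embedding_dim)\n",
    "\n",
    "linear_k = nn.Linear(embedding_dim, inner_dim)\n",
    "linear_q = nn.Linear(embedding_dim, inner_dim)\n",
    "\n",
    "keys = linear_k(torch.cat([memory, word_embs], dim=0))  # memory + current sequence\n",
    "queries = linear_q(word_embs)                           # current sequence only\n",
    "scores = torch.einsum('ibd,jbd->ijb', queries, keys)    # one score per (query, key, batch)\n",
    "\n",
    "print(keys.shape)     # torch.Size([13, 3, 17])\n",
    "print(queries.shape)  # torch.Size([7, 3, 17])\n",
    "print(scores.shape)   # torch.Size([7, 13, 3])"
   ]
  },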
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we compute scaled dot product attention as in the standard Transformer. Scaled dot product attention computes the attention score as the dot product between the query and key vectors. To prevent the scores from growing too large as the dimensionality of the vectors increases, we divide the raw scores by the square root of the key dimension $d_k$ (here, `inner_dim`).\n",
    "\n",
    "$$ \\textrm{Attention}(Q, K, V) = \\textrm{softmax}(\\frac{QK^T}{\\sqrt{d_k}})V $$\n",
    "\n",
    "![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/scaled_dot_product_attention.png?zoom=2&w=750)"
   ]
  },
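  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following sketch shows why the scaling matters. It uses standard-normal random vectors as an assumed stand-in for queries and keys and measures the spread of the raw dot products for two key dimensions. The standard deviation of the raw scores grows roughly like $\\sqrt{d_k}$, while the scaled scores stay close to 1. This keeps the softmax from saturating."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "stds = {}\n",
    "for d_k in (16, 256):\n",
    "    q = torch.randn(10000, d_k)\n",
    "    k = torch.randn(10000, d_k)\n",
    "    raw = (q * k).sum(dim=-1)   # 10,000 independent dot products\n",
    "    scaled = raw / d_k ** 0.5   # scaled dot products\n",
    "    stds[d_k] = (raw.std().item(), scaled.std().item())\n",
    "    print(d_k, stds[d_k])  # raw std is roughly sqrt(d_k); scaled std is roughly 1"
   ]
  },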
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be using einsum notation here to make the code easy to read: if you're not familiar with einsum, check out [this awesome tutorial](https://rockt.github.io/2018/04/30/einsum). In short, einsum describes the shapes of the inputs and output using one letter per dimension. Below, the inputs are shaped `(i, b, d)` and `(j, b, d)` and the output is shaped `(i, j, b)`. The same letter always denotes a dimension of the same size. Letters that appear in the inputs but not in the output are summed over. Here, `d` is summed, giving the dot product between each query and key vector, while `b` appears everywhere and acts as a batch dimension."
   ]
  },
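  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If einsum is new to you, it can help to check it against a more familiar operation. This sketch uses random tensors with the same shapes as above. It computes the `ibd,jbd->ijb` product both with `torch.einsum` and with `torch.bmm` (after moving the batch dimension to the front) and confirms that the two agree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "i, j, b, d = 7, 13, 3, 17\n",
    "q = torch.rand(i, b, d)\n",
    "k = torch.rand(j, b, d)\n",
    "\n",
    "# einsum: d is summed over, while i, j and b are kept\n",
    "scores_einsum = torch.einsum('ibd,jbd->ijb', q, k)\n",
    "\n",
    "# same computation with a batched matrix multiply:\n",
    "# (b, i, d) @ (b, d, j) -> (b, i, j), then reorder to (i, j, b)\n",
    "scores_bmm = torch.bmm(q.permute(1, 0, 2), k.permute(1, 2, 0)).permute(1, 2, 0)\n",
    "\n",
    "print(torch.allclose(scores_einsum, scores_bmm, atol=1e-5))  # True"
   ]
  },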
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "content_attn = torch.einsum(\"ibd,jbd->ijb\", q_tfmd, k_tfmd) / (inner_dim ** 0.5) # scale by sqrt(d_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice we're not yet applying the softmax activation. This is because we need a couple more pieces to get the full attention score. The first of these is the relative positional embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Relative positional encodings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One of the key ideas in the Transformer XL is relative positional encoding. Instead of having a single embedding represent each **absolute** position, the Transformer XL computes an embedding that represents the **distance** between any two tokens. This embedding is then used when computing the attention between the two tokens."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The authors use the following equation to compute the attention between a query vector $ q_i $ and key vector $k_j$:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{align}\n",
    "A^{rel}_{i,j} = \n",
    "    \\underbrace{E_{x_i}^TW_q^TW_{k,E}E_{x_j}}_{(a)}\n",
    "    + \\underbrace{E_{x_i}^TW_q^TW_{k,R}{\\color{green}{R_{i-j}}}}_{(b)}\n",
    "    \\\\ \n",
    "    + \\underbrace{{\\color{red}{u}}^TW_{k,E}E_{x_j}}_{(c)} \n",
    "    + \\underbrace{{\\color{red}{v}}^TW_{k,R}{\\color{green}{R_{i-j}}}}_{(d)}\n",
    "\\end{align}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, $E_{x}$ is the embedding for $x$, and the $W$ terms are all transformation matrices. Term (a) is the content-based attention term that we already computed above. Terms (b) and (d) depend on the relative positional embedding $R_{i-j}$, and therefore on the distance between $q_i$ and $k_j$. $u$ and $v$ are global bias terms that do not depend on the query position: $u$ biases attention toward certain content, and $v$ toward certain relative positions."
   ]
  },
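  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Terms (a) and (c) share the transformed keys, and $u$ does not depend on the query position. The two terms can therefore be computed together by adding $u$ to the transformed queries before the dot product. Likewise, terms (b) and (d) can be computed together by adding $v$ to the queries. The sketch below checks this identity for the content terms using random tensors; the shapes and variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "seq, klen, batch_size, d = 7, 13, 3, 17\n",
    "q = torch.rand(seq, batch_size, d)   # transformed queries (W_q E_{x_i})\n",
    "k = torch.rand(klen, batch_size, d)  # transformed keys (W_{k,E} E_{x_j})\n",
    "u = torch.rand(d)                    # global content bias\n",
    "\n",
    "term_a = torch.einsum('ibd,jbd->ijb', q, k)        # (seq, klen, batch)\n",
    "term_c = torch.einsum('d,jbd->jb', u, k)[None]     # (1, klen, batch): identical for every query i\n",
    "combined = torch.einsum('ibd,jbd->ijb', q + u, k)  # bias folded into the queries\n",
    "\n",
    "print(torch.allclose(term_a + term_c, combined, atol=1e-5))  # True"
   ]
  },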
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's move on to implementing terms (b) to (d). We'll start with the content bias (term (c) in the equation above), since it is the simplest to compute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "u = torch.rand(inner_dim).expand_as(q_tfmd) # global content bias (learned in the full model)\n",
    "content_attn = content_attn + torch.einsum(\"ibd,jbd->ijb\", u, k_tfmd) / (inner_dim ** 0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we compute the relative positional embeddings necessary for the positional attention terms. For the relative positional embeddings, the Transformer XL uses fixed sinusoidal embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([12., 11., 10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  0.])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pos_idxs = torch.arange(seq + prev_seq - 1, -1, -1.0, dtype=torch.float)\n",
    "pos_idxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3XmclNWd7/HPr6p6B3qBRoRuunGDKLI2yqLJaEyuEzU6MyZqxGVixkmiZpnJzU0mmbl3JrkzmVdyk2h2r5poRIwxGp3EiZqYjFHZWkBlM6DQ7NDQ0CxNr/WbP55qaKAR6HqK2r7v16teVfXUwzk/UL79cOo855i7IyIi2S+S7gJERCQcCnQRkRyhQBcRyREKdBGRHKFAFxHJEQp0EZEcoUAXEckRCnQRkRyhQBcRyRGxU9nZsGHDvL6+/lR2KSKS9V599dUd7l59vPNOaaDX19fT2Nh4KrsUEcl6ZtZ0IudpyEVEJEco0EVEcoQCXUQkRyjQRURyhAJdRCRHHDfQzewBM9tuZsv6HPu6ma0ys9fN7Ekzq0htmSIicjwncoX+E+DyI449D4x39wnAn4AvhlyXiIicpOMGuru/CLQccew5d+9OvJ0P1KSgtoN+/+Z2vv+HNansQkQk64Uxhv5R4D+P9aGZ3W5mjWbW2NzcPKAOXlmzg28/v5r2rp6B1igikvOSCnQz+xLQDcw51jnufq+7N7h7Q3X1ce9c7de0+io6e+Is29Q6wEpFRHLfgAPdzG4FrgRudHcPraJ+TK2rBGDRul2p7EZEJKsNKNDN7HLg88AH3b0t3JKONnRQEWdWl9G4ruX4J4uI5KkTmbY4F5gHjDWzjWZ2G/BdYDDwvJktNbMfprhOptVX0di0i3g8pf8YEBHJWsddbdHdb+jn8P0pqOUdNdRX8eiiDaxp3sc5pw0+1d2LiGS8rLlTdFp97zi6hl1ERPqTNYE+uqqU6sFFNOqLURGRfmVNoJsZ0+ordYUuInIMWRPoAA11VWzcdYAtrQfSXYqISMbJqkCfVl8FoGEXEZF+ZFWgv+v0wZQWRjUfXUSkH1kV6LFohCmjK3XHqIhIP7Iq0AEa6itZuXUPe9q70l2KiEhGybpAn1ZfhTssbtJVuohIX1kX6JNqK4hGTF+MiogcIesCvawoxnkjh2g+uojIEbIu0CGYj750w246u+PpLkVEJGNkZaBPq6+kozvOss3a8EJEpFdWBvrUxEJdmo8uInJIVgb68MHF1A8t1Xx0EZE+sjLQIVgfvXFdCyne/U5EJGtkbaBPq69kV1sXbzXvT3cpIiIZIWsDveHgQl0aRxcRgSwO9DOGlTG0rFDj6CIiCVkb6GZGQ30ljU26QhcRgSwOdAjWdWna2cb2Pe3pLkVEJO2yOtAPjqNroS4RkewO9PNGDqG4IKJ1XUREyPJAL4hGmFxbqZUXRUQ4gUA3swfMbLuZLetzrMrMnjez1YnnytSWeWzT6itZvrmVfR3d6SpBRCQjnMgV+k+Ay4849gXgd+5+NvC7xPu0aKivIu6wdP3udJUgIpIRjhvo7v4icOQg9dXAg4nXDwLXhFzXCZs8uoKIoXF0Ecl7Ax1DP83dtyRebwVOO9aJZna7mTWaWWNzc/MAuzu2wcUFvOv0IZqPLiJ5L+kvRT1YHeuYK2S5+73u3uDuDdXV1cl2169p9VUsWb+brh5teCEi+Wuggb7NzE4HSDxvD6+kk9dQX0lbZw8rt+xJZxkiImk10EB/Grgl8foW4KlwyhmYhrrgBiOt6yIi+exEpi3OBeYBY81so5ndBnwNeJ+ZrQYuS7xPmxHlxdRWlbBorcbRRSR/xY53grvfcIyP3htyLUmZVlfFi6ubcXfMLN3liIiccll9p2hfDfVV7NjXybqdbekuRUQkLXIm0KclNo7WfHQRyVc5E+hnVg+iorRAOxiJSN7KmUCPRIyGOi3UJSL5K2cC
HYJx9Ld37GfHvo50lyIicsrlVKD3jqPrKl1E8lFOBfr4UeUUxiIaRxeRvJRTgV4UizKppoJF2pJORPJQTgU6wLQxlSzf1Epbpza8EJH8knOB3lBfRXfcWbpBG16ISH7JuUCfMroSM30xKiL5J+cCvbykgLGnDdYdoyKSd3Iu0CHY8GJx0y66teGFiOSRnAz0hvpK9nf2sGrr3nSXIiJyyuRkoE+rDza80Hx0EcknORnoIytKGFVRovnoIpJXcjLQIRh2aVzXQrCHtYhI7svhQK9i254ONu46kO5SREROiZwNdG14ISL5JmcD/ZzhgxlcHGORbjASkTyRs4F+aMMLXaGLSH7I2UCHYBx99fZ97Nrfme5SRERSLqcD/eB8dE1fFJE8kFSgm9lnzWy5mS0zs7lmVhxWYWGYUFNOYVQbXohIfhhwoJvZKOBTQIO7jweiwPVhFRaG4oIo59eUa6aLiOSFZIdcYkCJmcWAUmBz8iWFq6G+kjc2tdLe1ZPuUkREUmrAge7um4BvAOuBLUCruz8XVmFhmVZXRVeP85o2vBCRHJfMkEslcDUwBhgJlJnZ7H7Ou93MGs2ssbm5eeCVDtDUuuAGI30xKiK5Lpkhl8uAte7e7O5dwBPAzCNPcvd73b3B3Ruqq6uT6G5gKssKOXv4II2ji0jOSybQ1wPTzazUzAx4L7AynLLC1VBfxatNu+iJa6EuEcldyYyhLwAeBxYDbyTaujekukI1rb6Sve3d/GmbNrwQkdyV1CwXd//f7j7O3ce7+03u3hFWYWHShhcikg9y+k7RXjWVJYwYUqyFukQkp+VFoJvZwQ0vRERyVV4EOgTDLptb29m0WxteiEhuyptAb0hseKGrdBHJVXkT6ONGDGFQUUzz0UUkZ+VNoEcjxpS6Shr1xaiI5Ki8CXSAaXWVvLltL61tXekuRUQkdHkV6A31VbjD4vW6SheR3JNXgT6ptoJYxDSOLiI5Ka8CvaQwyvhR5RpHF5GclFeBDsG6Lks37qajWxteiEhuybtAb6ivorM7zrJNrekuRUQkVPkX6IkNLxau1bCLiOSWvAv0oYOKOKO6THeMikjOybtAh2Cf0camXcS14YWI5JC8DPSG+kpaD3SxYsuedJciIhKavAz0S8cNpzAa4bHGDekuRUQkNHkZ6EMHFXHFhNN5YvEm9nV0p7scEZFQ5GWgA9w0o459Hd08uXhjuksREQlF3gb65NoKxo8awkPzmnDXl6Mikv3yNtDNjJun17N6+z4WrNUURhHJfnkb6ABXTRxJeUkBP53XlO5SRESSlteBXlIY5UNTa3h2+Va27WlPdzkiIknJ60AHmD29ju6488iC9ekuRUQkKUkFuplVmNnjZrbKzFaa2YywCjtV6oeV8Z5zqpm7cD1dPfF0lyMiMmDJXqHfDfzG3ccBE4GVyZd06t08o47tezt4bvm2dJciIjJgAw50MysH3g3cD+Dune6+O6zCTqU/GzucURUlPDRvXbpLEREZsGSu0McAzcCPzWyJmd1nZmUh1XW4pY/AU3empGmAaMSYPb2OBWtbeHPr3pT1IyKSSskEegyYAvzA3ScD+4EvHHmSmd1uZo1m1tjc3DywnvZuhSU/hS2vJVHuO7tuWi2FsQgPz9cURhHJTskE+kZgo7svSLx/nCDgD+Pu97p7g7s3VFdXD6ynabdB0RB46dsDLvZ4qsoKuXLC6TyxeCN727tS1o+ISKoMONDdfSuwwczGJg69F1gRSlVHKi6Hho/Cil/CzrdS0gXAzTPq2d/Zw5NLNqWsDxGRVEl2lstdwBwzex2YBPxr8iUdw/RPQqQAXvlOyrqYWFPO+aPKtb6LiGSlpALd3ZcmhlMmuPs17p66jToHnwaTb4Slc4Ix9RQwM26aUcea7fuY/7bWdxGR7JJdd4rOvAvi3TD/Bynr4oMTR1JRWsBP569LWR8iIqmQXYFedQac9xfQ+AC0t6aki+KCKB9uqOXZ5dvY2qr1XUQke2RX
oAPM+gx07IFF96esixsvHE3cnUcWan0XEcke2Rfop0+Asy6D+d+HrgMp6aJuqNZ3EZHsk32BDnDRZ2F/c/AFaYrcPKOO5r0dPLs8NV/AioiELTsDvW4W1EyDl++BntRs8vyec4ZTW1XCQ9r8QkSyRHYGullwlb67KbjZKAWiEWP2hXUsXNvCqq17UtKHiEiYsjPQAc75cxg2Fl76FqToJqAPNWh9FxHJHtkb6JEIXPQZ2LYM1vw2JV1UlRVy1YSRPLl4k9Z3EZGMl72BDjD+WhhSE1ylp8jNM+rY39nDE4u1vouIZLbsDvRYIcy8E5pehvULjn/+AEysrWBiTTk/na/1XUQks2V3oANMuRlKquDl1C2tO3t6sL7LvLd3pqwPEZFkZX+gF5bBhR+HN5+B7anZ0vSq3vVdNIVRRDJY9gc6wAV/AwVl8PLdKWm+uCDKdQ21PLdiG1taU3N3qohIsnIj0EurYOqt8MbPYXdq1l+ZPb2OuDtzF2h9FxHJTLkR6AAz7gAMXvluSpqvrSrlkrHDmbtoA53dWt9FRDJP7gR6+SiYcB0sfgj270hJFzdN1/ouIpK5cifQAWZ9CrrbYcGPUtL8e86pZnRVqb4cFZGMlFuBXj0Wxl0BC++Fjr2hNx+JGLOnj2bhOq3vIiKZJ7cCHYJFu9p3w6sPpqT5D02tpSgW0VW6iGSc3Av0mgaovxjmfRe6O0JvvrKskKsmjuTJJZvYo/VdRCSD5F6gQ3CVvncLvP5YSpq/eUYdbZ09PPHqxpS0LyIyELkZ6GdeCiMmBDcaxXtCb35CTQUTayu0vouIZJTcDPTeDTB2roZVv05JFzdPr+Ot5v3Me0vru4hIZkg60M0samZLzOxXYRQUmnOvhsoxKdsA44oJp1NZWqAt6kQkY4Rxhf5pIDWrYiUjEoVZn4bNi2Hti6E3X1wQ5cPTanl+pdZ3EZHMkFSgm1kNcAVwXzjlhGziDTDoNHjpmylpfvaFWt9FRDJHslfo3wY+D2Tm4iYFxcEaL2//ATYtDr352qpSLh07nEcWan0XEUm/AQe6mV0JbHf3V49z3u1m1mhmjc3NzQPtbuCm/jUUladsA4ybZtSxY18Hv9H6LiKSZslcoc8CPmhm64BHgUvN7OEjT3L3e929wd0bqqurk+hugIqHwAUfgxVPw441oTf/7rOrqRtayk/nrQu9bRGRkzHgQHf3L7p7jbvXA9cDL7j77NAqC9OFH4dYEbwS/gYYkYgx+8I6Fq3bxcotWt9FRNInN+ehH2nQcJg8G5bOhT2bQ2/+Qw01lBZG+eqvV9AT141GIpIeoQS6u//B3a8Mo62UmXkXeBzmfz/0pitKC/k/V53Hy2t28oM/hD+sIyJyIvLjCh2gsh7G/yU0/hgO7Aq9+Q811HD1pJF88/k/sXBtS+jti4gcT/4EOsCsz0DnPlgU/rR5M+P//sX5jK4q5VNzl9CyvzP0PkRE3kl+BfqI8XD2+2H+D6GzLfTmBxXF+O5HptCyv5PP/fw14hpPF5FTKL8CHYJFu9p2wNI5KWl+/Khyvnzlu3hh1Xbuf2ltSvoQEelP/gX66BlQeyG8fA/0pGaDipum13H5eSP499+sYsn68MfrRUT6k3+BbgYXfw5a18Nz/5iiLox/v3YCI8qLuWvuEloPaGcjEUm9/At0gHPeD9M/CQt+EMx6SYHykgK+c8Nktra284VfvK6NMEQk5fIz0AHe9xU46zJ45nMpWV4XYPLoSj5/+Vj+c9lWHp6vddNFJLXyN9CjMbj2Aag6E352E+x8KyXdfOyiM7hkbDVf+dVKlm9uTUkfIiKQz4EOUFwOH3k0GFefez0c2B16F5GI8f8+PInKsgLufGQJ+zq6Q+9DRATyPdABqs6A6x6Glrfh8b+GnvADt6qskHuun0zTzv186ck3NJ4uIimhQAeovwiu+Ca89QI896WUdHHhGUP57GXn8NTSzfy8cWNK+hCR/KZA7zX1Fph+Byz4ITQ+kJIu
PnnJWcw6ayj/9PQy/rRtb0r6EJH8pUDv6/1fCZYGeOZ/wtv/FXrz0YjxresmMagoxp2PLOZAZ0/ofYhI/lKg9xWJwl/dD0PPgsduTsnMl+GDi/nWdZNYvX0f//wfy0NvX0TylwL9SMVD4IZHwSLwyHUpmfly8dnVfPLPzuTRRRt4aumm0NsXkfykQO9P1Zhg5suudSmb+fLZy86hoa6Sf3jiDdbu2B96+yKSfxTox1I/C65MzHx59h9Cbz4WjXDPDZMpiEW4Y85i2rs0ni4iyVGgv5MpN8OMO2Hhj2DR/aE3P7KihG9cO5EVW/bwb8+sDL19EckvCvTjed+/pHTmy2XnnsZtF43hwXlN/GbZltDbF5H8oUA/nt6ZL8POSdnMl/91+Tgm1JTz+cdfZ0NL+DspiUh+UKCfiOIhcMPcINwf+XDom0wXxiJ894YpuMNdc5fQ1RMPtX0RyQ8K9BN1cOZLE/z81tBnvoweWsrX/moCSzfs5hvPvhlq2yKSHxToJ6NuJlz5LXj7D/DsF0Nv/ooJpzN7+mh+9OLb/H7V9tDbF5HcNuBAN7NaM/u9ma0ws+Vm9ukwC8tYU25KzHy5FxbdF3rzX77iXMaNGMzfPbaULa0HQm9fRHJXMlfo3cDfu/u5wHTgDjM7N5yyMtz7/gXO/h/wzOfhrd+H2nRxQZTv3TiFju44t/2kUV+SisgJG3Cgu/sWd1+ceL0XWAmMCquwjBaJwl/dB9Vj4ee3wI41oTZ/ZvUgvnfjFDbsauOKe/7I8yu2hdq+iOSmUMbQzawemAwsCKO9rNC75kskBnOvC33myyVjh/Pruy5m9NBS/uahRv7tmZWa/SIi7yjpQDezQcAvgM+4+55+Pr/dzBrNrLG5uTnZ7jJLZR1cN6fPzJeuUJsfPbSUxz8+8+AXpTfcO1/j6iJyTEkFupkVEIT5HHd/or9z3P1ed29w94bq6upkustMdTPgqruDmS9zPgSt4e5GVFwQ5avXnM/d109ixZY9XHHPS7z4pxz7wSgioUhmlosB9wMr3f2b4ZWUhSbfGIT6hoXw/RmwZA6EvG/o1ZNG8fSdF1E9qIhbfryQbz73Jj1x7U0qIockc4U+C7gJuNTMliYeHwipruwz9Vb4xMsw4nx46pMw93rYuzXULs4aPohf3jGLa6fUcM8La7jp/gU07+0ItQ8RyV52Knegb2ho8MbGxlPWX1rE48HqjL/9Z4gVwQe+AedfC2ahdvNY4wb+6allDC4u4Ds3TGb6GUNDbV9EMoeZveruDcc7T3eKhi0SgemfgI+/FCzo9cTH4GezYV+4494fbqjll3fMYnBxjI/8//l87/driGsIRiSvKdBTZdhZ8NHfBDchrX4evn8hLH8y1C7GjRjC03dexBUTRvL1Z9/kow8uYtf+zlD7EJHsoUBPpUgUZn0a/vZFqKgLpjb+/K+hrSW0LgYVxbjn+kl85ZrxvLJmJ1fc80debQp3TryIZAcF+qkwfBzc9jxc+o+w8j/gexfCql+H1ryZcdP0On7xiZlEo8Z1P5rHfX98m1P5/YiIpJ8C/VSJxuDdn4Pb/wCDT4NHPwJP/G2od5ieX1POr+66mEvHDeerv17Jxx9+ldYD4d7sJCKZS4F+qo0YDx97Ad7zBVj2eDBvffXzoTVfXlLAj26aypeveBe/W7mdq77zEss2tYbWvohkLgV6OsQK4ZIvwsd+ByWVMOdaeOpOaD9q5YQBMTM+dvEZ/OxvZ9DVE+cvv/8KD89v0hCMSI5ToKfTyEnBEMxFfwdL5wRX6yEuxzu1rpJff+piZp41lC//chl3PrKENdv3hta+iGQW3ViUKTY2wpMfh52roeG2YLpj0aBQmo7HnR/811vc/dvVdPbEufjsYdwyo55Lxg0nGgn3hicRCd+J3likQM8kXQfgha/CvO9BxWj4wNfhrMuC6Y8h2LGvg0cXrufh+evZuqed2qoSbp5ez4cbaikvLQilDxEJnwI9m62fD7/8BLS8
DUNGwcQbYNJHYOiZoTTf1RPnueXbePCVdSxc10JJQZRrJo/i1pn1jB0xOJQ+RCQ8CvRs190RzFVfOgfeegE8DqNnBis7nntNaMMxyze38tArTfxy6SY6uuNMP6OKW2eO4bJ3DScW1VcsIplAgZ5L9myG1+YGy/K2vAUFZXDeNTDpRqibGcrCX7v2d/Loog08PL+JTbsPMKqihNnT67h+Wi2VZYUh/CZEZKAU6LnIHTYsgCUPB+vCdO6DyjHBVfvEG6C8Jukuunvi/Hbldh58ZR3z3t5JUSzC1ZNGcsvMes4bWR7Cb0JETpYCPdd17ocVTwdDMuv+CBiceUlw1T7uSigoTrqLVVv38OArTTy5ZCPtXXGm1Vdy68wxvP+80yjQcIzIKaNAzycta4MhmaWPQOsGKC6H8dcGV+4jpyQ9JNPa1sVjjRt4aP46NrQcYMSQYmZPH831F4xm2KCikH4TInIsCvR8FI/DuheDsfaVT0N3O1S/Kwj2CdfBoOFJNd8Td15YFQzHvLRmB4XRCFPqKrigvooLxgxl8ugKyopiIf1mRKSXAj3ftbfCsieCIZmNiyASgzMvhdHTYVQDjJwMxUMG3Pya7Xv52aINzH+7heWbW4k7RCPG+JFDuGBMFdPqg4e+UBVJngJdDml+Mwj2lb8KZskAYFA9Ngj3mqkwaioMPy9YFfIk7W3vYvH63Sxcu5NFa3exdONuOrvjAJxz2iCm1VdxwZjgcXp5SYi/MZH8oECX/rW1wObFsGlxsNzApkZo2xl8FiuB0ydCTUMQ8KOmBnesnuQYfHtXD69vbGXRuhYWrm3h1aZd7OvoBqC2qiQI+ETIjxlWhoW836pIrlGgy4lxh91NiXB/NXhseS0Yfwcoq06Ee+JKfuQUKKk4qS66e+Ks2rqXBWtbWLS2hUXrWtiZ2Cpv2KAiLhhTeXCIZtyIwbqhSeQICnQZuJ4u2LY8uHrfmAj5HW8e+nzo2Yeu4IeeAeWjgznwhaUn1Ly781bzfhatCwJ+wdoWNu0+AATj8KMqSqgbWkptVSl1VaWMripl9NDgeXCx1pyR/KNAl3C1t8LmJYeu5Dc2wv7th59TOhTKa6GiNng+8nVp1TGHbzbvPsCidS2s3raPppY21u/cz/qWNna1Hb7jUlVZYb9BXze0lNMGFxPR6pGSgxToklrusHcL7GoK5r7vXh88t26E3RuC111th/+agtJEuNf0H/qDTz/qS9k97V2s39nG+pbg0bSzjQ0tbTS17Gfz7nZ64of+/y2MRaitLEkEfBm1VaWMqiihqqyQytICKssKqSgp0JCOZJ0TDfSkJg2b2eXA3UAUuM/dv5ZMe5JFzGDIyODBjKM/dw/2Sz0q6NcHr7e8Bm07jmgzGlzll1RAcQWUVDKkpILxxRWMLwnec0YFnBd83lU0gm2dxaxrK2Td7p4g6BPhv2jdoS9ijzS4OEZVWSEVpYVUlRZQWZp4XVaQeC6kInG893VRLJwljEVSacCBbmZR4HvA+4CNwCIze9rdV4RVnGQxs2CIpbQq2JmpP51tsGfTodDfvSEI+QO7gx8G+7ZC86rgfcfR+6IWADWJx0Wx4oM/BBhcgVeX01kwhL2UccAL2e8F7OspZE9PjD3dMXZ1xtjVFWXnrig7tkRpao/Q0hWj3Qs5QCEHKKKDAiAYwikrjFJRWkhlWQGlhTFKC6OUFEQpKYxSWhiltDBGSUHv6yglvecURiktSHxe2PfzKIXRiGb4SKiSuUK/AFjj7m8DmNmjwNWAAl1OTGEpDDs7eBxPvCcYx29PhP2B3f28Trxvb8VaN1HUvpyi9tZg45B41zu3HwH6WcWgO1JMV6SITiumI15I+54COiigy6N0eJROj9IRj9AeD153EqPbo3QQYx9RuojRzaHjXcToShyPWwyLFhCJxiASxSIxLBIlEokSiUYhWpB4HSMaDZ4Pvo/1Hk88YjFi0SjRaIxIrICIWXC+GZFohEgkSjTSeyxCJBolGokQ
jUawSJRYNJo4fuiz3mPRaIRoJNir1iBo0wwziESMiHHofeKziAXn933ue45xqD1LnCPJSybQRwEb+rzfCFyYXDkixxCJHrriH4iebug+EIT7wUdb8HzU8cRn3e3EutqIdR2gpO9n8S7o6Qza7OlMvG/He7rwnk68u6vP8S4s3k0k3nns2uKJRwaLu+GAY4lH8Br6P973dRzo7udzgEPfgPS+P3TcCQLfj/isl2NgRx8/+tx+ft1RDD/q8PHa7T3We/aR/Ryu7f3f4NwZf95P3+FJ+cIbZnY7cDvA6NGjU92dSP+iMYgOhqLU7chk9BcBCe7BvzKO/GHQ0wneE6zD4z3BOb3PfV8f9nz0uR7voaenO3h09xD3OB6PE4/3EHfHe3qPOfF4T+KzOO5xPH7ofI/7wWPufujZ47h78PvAD712xw++jx98j/th5x9833tu759J4jl45YkUjB/25+Y45hxs9+BHieO95x8MVj/0eb//HTh4Qt8P+jnvqF98jPb6c/TxIWUnd//GQCQT6JuA2j7vaxLHDuPu9wL3QjDLJYn+RLKXWeKHSgwKwl/+wAj+MmtptPyWzPytRcDZZjbGzAqB64GnwylLRERO1oB/oLt7t5ndCTxLMG3xAXdfHlplIiJyUpL6F5q7PwM8E1ItIiKSBN0yJyKSIxToIiI5QoEuIpIjFOgiIjlCgS4ikiNO6fK5ZtYMNA3wlw8Ddhz3rPTK9BozvT7I/BozvT5QjWHItPrq3L36eCed0kBPhpk1nsh6wOmU6TVmen2Q+TVmen2gGsOQ6fUdi4ZcRERyhAJdRCRHZFOg35vuAk5ApteY6fVB5teY6fWBagxDptfXr6wZQxcRkXeWTVfoIiLyDrIi0M3scjN708zWmNkX0l1PX2ZWa2a/N7MVZrbczD6d7pqOxcyiZrbEzH6V7lqOZGYVZva4ma0ys5Vm1s/O0+llZp9N/DdeZmZzzaw4A2p6wMy2m9myPseqzOx5M1udeK7MsPq+nvjv/LqZPWlmqd/54SRr7PPZ35uZm9mwdNR2sjI+0PtsRv3nwLnADWZ2bnqrOkw38Pfufi4wHbgjw+rr69PAynQXcQx3A79x93HARDKsTjMbBXwKaHD38QRLRl+f3qoA+AnFhMnkAAAC3klEQVRw+RHHvgD8zt3PBn6XeJ8uP+Ho+p4Hxrv7BOBPwBdPdVFH+AlH14iZ1QLvB9af6oIGKuMDnT6bUbt7J9C7GXVGcPct7r448XovQRCNSm9VRzOzGuAK4L5013IkMysH3g3cD+Dune6+O71V9SsGlJhZDCgFNqe5Htz9RaDliMNXAw8mXj8IXHNKi+qjv/rc/Tl37068nU+w21naHOPPEOBbwOfpd++5zJQNgd7fZtQZF5gAZlYPTAYWpLeSfn2b4H/OTNyOeAzQDPw4MSR0n5mVpbuovtx9E/ANgqu1LUCruz+X3qqO6TR335J4vRU4LZ3FHMdHgf9MdxFHMrOrgU3u/lq6azkZ2RDoWcHMBgG/AD7j7nvSXU9fZnYlsN3dX013LccQA6YAP3D3ycB+0jtMcJTEOPTVBD98RgJlZjY7vVUdn3vfXZEzi5l9iWDIck66a+nLzEqBfwD+Kd21nKxsCPQT2ow6ncysgCDM57j7E+mupx+zgA+a2TqCIatLzezh9JZ0mI3ARnfv/ZfN4wQBn0kuA9a6e7O7dwFPADPTXNOxbDOz0wESz9vTXM9RzOxW4ErgRs+8udNnEvzgfi3xd6YGWGxmI9Ja1QnIhkDP6M2ozcwIxn5Xuvs3011Pf9z9i+5e4+71BH9+L7h7xlxduvtWYIOZjU0cei+wIo0l9Wc9MN3MShP/zd9Lhn1x28fTwC2J17cAT6WxlqOY2eUEw38fdPe2dNdzJHd/w92Hu3t94u/MRmBK4v/TjJbxgZ748qR3M+qVwGMZthn1LOAmgqvepYnHB9JdVBa6C5hjZq8Dk4B/TXM9h0n86+FxYDHwBsHfnbTfTWhm
c4F5wFgz22hmtwFfA95nZqsJ/mXxtQyr77vAYOD5xN+XH6arvneoMSvpTlERkRyR8VfoIiJyYhToIiI5QoEuIpIjFOgiIjlCgS4ikiMU6CIiOUKBLiKSIxToIiI54r8BRns7vlWpwL4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim))\n",
    "sinusoid_inp = torch.einsum(\"i,j->ij\", pos_idxs, inv_freq)\n",
    "plt.plot(sinusoid_inp[0, :].detach().numpy());\n",
    "plt.plot(sinusoid_inp[6, :].detach().numpy());"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([13, 1, 32])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "relative_positional_embeddings = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)[:,None,:]\n",
    "relative_positional_embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's wrap this up in its own class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEmbedding(nn.Module):\n",
    "    def __init__(self, d):\n",
    "        super().__init__()\n",
    "        self.d = d\n",
    "        inv_freq = 1 / (10000 ** (torch.arange(0.0, d, 2.0) / d))\n",
    "        # register_buffer tells PyTorch that this tensor is part of the module's state\n",
    "        # (but not a trainable parameter), so it is saved in the state_dict and moved\n",
    "        # to the GPU along with the model\n",
    "        self.register_buffer(\"inv_freq\", inv_freq)\n",
    "        \n",
    "    def forward(self, positions: torch.LongTensor, # (seq, )\n",
    "               ):\n",
    "        # outer product\n",
    "        sinusoid_inp = torch.einsum(\"i,j->ij\", positions.float(), self.inv_freq)\n",
    "        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)\n",
    "        return pos_emb[:,None,:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
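    "We also apply a separate linear transformation to the positional embeddings (this corresponds to $W_{k,R}$ in the equation above). It is distinct from the transformations used for the keys and values."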
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_p = nn.Linear(embedding_dim, inner_dim)\n",
    "pos_tfmd = linear_p(relative_positional_embeddings)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
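    "This time, we add the positional bias $v$ to the queries before taking the dot product with the transformed positional embeddings. This computes terms (b) and (d) in a single einsum."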
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 13, 3])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v = torch.rand(inner_dim) # global positional bias (learned in the full model)\n",
    "pos_attn = torch.einsum(\"ibd,jd->ijb\", q_tfmd + v, pos_tfmd[:,0,:]) / (inner_dim ** 0.5) # scale\n",
    "pos_attn.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relative positional term depends on each query-key pair. A naive implementation would therefore compute a separately projected positional embedding $W_{k,R}R_{i-j}$ for every pair $(i, j)$, which is quadratic in the sequence length. Luckily, the authors proposed a trick. There are only `seq + prev_seq` distinct distances, so we compute the positional scores once per distance, which is what we just did. We then shift each query's row so that every key lines up with its correct distance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_pad = torch.zeros((seq, 1, batch_size), dtype=torch.float)\n",
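    "# this padding + reshaping shifts row i of pos_attn so that column j holds the score\n",
    "# for the distance prev_seq + i - j between query i and key j (for all non-future keys)\n",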
    "pos_attn = (torch.cat([zero_pad, pos_attn], dim=1)\n",
    "                    .view(seq + prev_seq + 1, seq, batch_size)[1:]\n",
    "                    .view_as(pos_attn)) "
   ]
  },
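  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping trick is easy to get wrong, so it is worth verifying. The sketch below wraps the same padding and reshaping in a small helper; the name `rel_shift` is chosen here for illustration. It then compares the result with a naive loop that looks up, for every query $i$ and non-future key $j$, the score for the distance `prev_seq + i - j`. The trick leaves meaningless values in the entries for future keys, which is fine because they are masked out later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def rel_shift(x):\n",
    "    # x: (seq, klen, batch), column k scores the relative distance klen - 1 - k\n",
    "    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), dtype=x.dtype)\n",
    "    return (torch.cat([zero_pad, x], dim=1)\n",
    "            .view(x.size(1) + 1, x.size(0), *x.size()[2:])[1:]\n",
    "            .view_as(x))\n",
    "\n",
    "torch.manual_seed(0)\n",
    "seq, prev_seq, batch_size = 7, 6, 3\n",
    "klen = seq + prev_seq\n",
    "x = torch.rand(seq, klen, batch_size)\n",
    "shifted = rel_shift(x)\n",
    "\n",
    "# naive version: query i sits at absolute position prev_seq + i and key j at position j,\n",
    "# so their distance is prev_seq + i - j, which is stored in column klen - 1 - distance\n",
    "for i in range(seq):\n",
    "    for j in range(prev_seq + i + 1):  # only non-future keys are meaningful\n",
    "        dist = prev_seq + i - j\n",
    "        assert torch.equal(shifted[i, j], x[i, klen - 1 - dist])\n",
    "print('rel_shift matches the naive lookup on all non-future entries')"
   ]
  },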
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The attention is computed as the sum of content and positional attention."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_attn = content_attn + pos_attn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we do language modeling, we need to prevent the model from looking at the words that it is supposed to predict. As in the original Transformer, we achieve this by setting the raw attention scores for those future positions to $-\\infty$ before the softmax, so that they receive zero attention weight. This masks out the words that we don't want the model to see."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = torch.triu(\n",
    "    torch.ones((seq, seq + prev_seq)),\n",
    "    diagonal=1 + prev_seq,\n",
    ").bool()[...,None] # masked_fill expects a boolean mask\n",
    "raw_attn = raw_attn.masked_fill(mask, -float('inf'))"
   ]
  },
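  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the mask concrete, the sketch below builds it for a smaller, assumed example (`seq=4`, `prev_seq=2`) and applies it to random scores. Every query can attend to all memory positions and to current-sequence positions up to and including its own. Masked scores become exactly 0 after the softmax. No row is fully masked, which matters because a row of all $-\\infty$ would produce NaNs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "seq, prev_seq = 4, 2  # small sizes so the mask is easy to read\n",
    "mask = torch.triu(torch.ones(seq, seq + prev_seq), diagonal=1 + prev_seq).bool()\n",
    "print(mask.long())\n",
    "# tensor([[0, 0, 0, 1, 1, 1],\n",
    "#         [0, 0, 0, 0, 1, 1],\n",
    "#         [0, 0, 0, 0, 0, 1],\n",
    "#         [0, 0, 0, 0, 0, 0]])\n",
    "\n",
    "scores = torch.rand(seq, seq + prev_seq)\n",
    "weights = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=1)\n",
    "print(weights[0])  # the last three weights are exactly 0"
   ]
  },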
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now compute the outputs as the weighted sum of the value vectors using the attention scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 3, 17])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attn = torch.softmax(raw_attn, dim=1)\n",
    "attn_weighted_sum = torch.einsum(\"ijb,jbd->ibd\", attn, v_tfmd)\n",
    "attn_weighted_sum.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we project the attention weighted sums back to their original dimension and apply a residual connection and layer normalization. We apply layer normalization after the residual connection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 3, 32])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "linear_out = nn.Linear(inner_dim, embedding_dim)\n",
    "layer_norm = nn.LayerNorm(embedding_dim)\n",
    "output = layer_norm(word_embs + linear_out(attn_weighted_sum))\n",
    "output.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MultiHeadAttention (MHA): The core component"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Aggregating all the above and applying a couple of optimizations by grouping some computations together as well as adding dropout, we get the following MultiHeadAttention module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self, d_input: int, d_inner: int, n_heads: int=4, \n",
    "                 dropout: float=0.1, dropouta: float=0.):\n",
    "        super().__init__()\n",
    "        self.d_input = d_input\n",
    "        self.d_inner = d_inner\n",
    "        self.n_heads = n_heads\n",
    "        # this layer applies the linear transformation required\n",
    "        # for the keys and values for all heads at once for efficiency\n",
    "        self.linear_kv = nn.Linear(\n",
    "            d_input, \n",
    "            (d_inner * n_heads * 2), # 2 is for keys and values\n",
    "            bias=False, # we don't apply bias, making this a simple matrix multiplication\n",
    "        )\n",
    "        # for queries (will not be concatenated with memorized states so separate)\n",
    "        self.linear_q = nn.Linear(\n",
    "            d_input, d_inner * n_heads,\n",
    "            bias=False\n",
    "        )\n",
    "        # for positional embeddings\n",
    "        self.linear_p = nn.Linear(\n",
    "            d_input, d_inner * n_heads,\n",
    "            bias=False\n",
    "        )\n",
    "        self.scale = 1 / (d_inner ** 0.5) # for scaled dot product attention\n",
    "        self.dropa = nn.Dropout(dropouta)\n",
    "        # we will use this to project back to the input dimension\n",
    "        self.lout = nn.Linear(self.d_inner * self.n_heads, self.d_input, bias=False)\n",
    "        self.norm = nn.LayerNorm(self.d_input)\n",
    "        self.dropo = nn.Dropout(dropout)\n",
    "        \n",
    "    def _rel_shift(self, x):\n",
    "        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),\n",
    "                               device=x.device, dtype=x.dtype)\n",
    "        return (torch.cat([zero_pad, x], dim=1)\n",
    "                    .view(x.size(1) + 1, x.size(0), *x.size()[2:])[1:]\n",
    "                    .view_as(x)) \n",
    "        \n",
    "    def forward(self, input_: torch.FloatTensor, # (cur_seq, b, d_in)\n",
    "                pos_embs: torch.FloatTensor, # (cur_seq + prev_seq, d_in)\n",
    "                memory: torch.FloatTensor, # (prev_seq, b, d_in)\n",
    "                u: torch.FloatTensor, # (H, d)\n",
    "                v: torch.FloatTensor, # (H, d)\n",
    "                mask: Optional[torch.BoolTensor]=None, # True = masked out\n",
    "        ):\n",
    "        \"\"\"\n",
    "        pos_embs: we pass the positional embeddings in separately\n",
    "            because we need to handle relative positions\n",
    "        input shape: (seq, bs, self.d_input)\n",
    "        pos_embs shape: (seq + prev_seq, self.d_input)\n",
    "        output shape: (seq, bs, self.d_input)\n",
    "        \"\"\"\n",
    "        cur_seq = input_.shape[0] #  sequence length of current segment\n",
    "        prev_seq = memory.shape[0] # sequence length of previous segment\n",
    "        H, d = self.n_heads, self.d_inner\n",
    "        input_with_memory = torch.cat([memory, input_], dim=0) # concatenate recurrent memory\n",
    "                                                               # across sequence dimension\n",
    "\n",
    "        # we will use the following symbols to represent the shape of the tensors\n",
    "        # cs: current sequence length, b: batch, H: number of heads\n",
    "        # d: inner dimension, ps: previous sequence length\n",
    "        # The key and value are now conditioned on the preceding context\n",
    "        k_tfmd, v_tfmd = \\\n",
    "            torch.chunk(self.linear_kv(input_with_memory), 2, dim=-1) # (cs + ps, b, H * d)\n",
    "        q_tfmd = self.linear_q(input_) # (cs, b, H * d)\n",
    "\n",
    "        # apply scaled dot product attention\n",
    "        # look at the following dimensions carefully, since this is the key operation\n",
    "        # in the Transformer/Transformer XL architecture\n",
    "        \n",
    "        _, bs, _ = q_tfmd.shape\n",
    "        assert bs == k_tfmd.shape[1]\n",
    "        # content-based attention term ((a) + (c) in the paper)\n",
    "        # this is the standard attention term in the original Transformer, except without positional embeddings\n",
    "        # which are handled separately in the Transformer XL (see below)\n",
    "        # here, i corresponds to the number of queries = number of current inputs/targets (seq-wise)\n",
    "        # j corresponds to the number of key/values = number of vectors that we can use to compute the \n",
    "        # vector for each query\n",
    "        content_attn = torch.einsum(\"ibhd,jbhd->ijbh\", (\n",
    "                (q_tfmd.view(cur_seq, bs, H, d) + # (a)\n",
    "                 u), # (c): u represents the global (independent of the query)\n",
    "                     # bias towards certain key/values = words\n",
    "                     # (u has shape (H, d): one bias per head, shared across layers)\n",
    "                 k_tfmd.view(cur_seq + prev_seq, bs, H, d) # There is no positional information to be found here\n",
    "        )) # (cs, cs + ps, b, H)\n",
    "        \n",
    "        # position-based attention term ((b) + (d) in the paper)\n",
    "        # this attention is solely based on the position of the key/values\n",
    "        # (i.e. it does not take the content of the key/values into account)\n",
    "        p_tfmd = self.linear_p(pos_embs) # (cs + ps, H * d)\n",
    "        position_attn = torch.einsum(\"ibhd,jhd->ijbh\", (\n",
    "                (q_tfmd.view(cur_seq, bs, H, d) + # (b)\n",
    "                 v), # (d): v represents the global (independent of the query)\n",
    "                     # bias towards certain positions\n",
    "                 p_tfmd.view(cur_seq + prev_seq, H, d) # Notice there is no content information\n",
    "                                                        # regarding keys and values here!\n",
    "        )) # (cs, cs + ps, b, H)\n",
    "        \n",
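    "        # shift each row so that entry (i, j) scores relative distance prev_seq + i - j\n",
    "        # (entries with j > prev_seq + i are leftovers and are masked out below)\n",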
    "        position_attn = self._rel_shift(position_attn)\n",
    "        \n",
    "        # the attention is the sum of content-based and position-based attention\n",
    "        attn = content_attn + position_attn\n",
    "\n",
    "        if mask is not None and mask.any().item():\n",
    "            attn = attn.masked_fill(\n",
    "                mask[...,None], -float('inf'))\n",
    "        attn = torch.softmax(attn * self.scale, # rescale to prevent values from exploding\n",
    "                             dim=1) # normalize across the value sequence dimension\n",
    "        attn = self.dropa(attn)\n",
    "        \n",
    "        attn_weighted_values = (torch.einsum(\"ijbh,jbhd->ibhd\",\n",
    "                                           (attn, # (cs, cs + ps, b, H)\n",
    "                                            v_tfmd.view(cur_seq + prev_seq, bs, H, d), # (cs + ps, b, H, d)\n",
    "                                           )) # (cs, b, H, d)\n",
    "                                .contiguous() # we need to change the memory layout to make `view` work\n",
    "                                .view(cur_seq, bs, H * d)) # (cs, b, H * d)\n",
    "\n",
    "        # Project back to input dimension and add residual connection\n",
    "        output = input_ + self.dropo(self.lout(attn_weighted_values))\n",
    "        output = self.norm(output)\n",
    "        return output"
   ]
  },
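  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make `_rel_shift` less mysterious, here is a small sketch that applies the same pad, reshape and slice trick to a 2-D matrix of relative distances. `rel_shift_2d` is a hypothetical helper (a 2-D restatement of the method above, not part of the model), and the sizes are arbitrary. Before the shift, column `j` of every row holds the distance `cs + ps - 1 - j` encoded by positional embedding `j`, because the positional indices run from `cs + ps - 1` down to 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def rel_shift_2d(x: torch.Tensor) -> torch.Tensor:\n",
    "    # hypothetical helper: same trick as MultiHeadAttention._rel_shift,\n",
    "    # restricted to a single (query, key) score matrix\n",
    "    zero_pad = torch.zeros((x.size(0), 1), dtype=x.dtype)\n",
    "    padded = torch.cat([zero_pad, x], dim=1)        # (cs, cs + ps + 1)\n",
    "    padded = padded.view(x.size(1) + 1, x.size(0))  # reinterpret the same memory\n",
    "    return padded[1:].view_as(x)                    # drop one row, reshape back\n",
    "\n",
    "cs, ps = 3, 2  # current segment length, memory length (arbitrary)\n",
    "# row i, column j: relative distance encoded by positional embedding j\n",
    "scores = torch.arange(cs + ps - 1, -1, -1).repeat(cs, 1)\n",
    "shifted = rel_shift_2d(scores)\n",
    "print(shifted)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the shift, the first `ps + i + 1` entries of row `i` read `ps + i, ps + i - 1, ..., 0`, i.e. the distance between query `i` (absolute position `ps + i`) and key `j`. The remaining upper-right entries are leftovers from the reshape, but they sit exactly at the future positions that the causal mask removes, so they never reach the softmax."
   ]
  },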
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test it on random inputs to check that it runs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "mha = MultiHeadAttention(32, 17, n_heads=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "inpt = torch.rand(7, 3, 32)\n",
    "pos = torch.rand(13, 32)\n",
    "mem = torch.rand(6, 3, 32)\n",
    "u, v = torch.rand(4, 17), torch.rand(4, 17)\n",
    "x1 = mha(inpt, pos, mem, u, v)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It runs without errors. Let's check that the output keeps the input shape `(cur_seq, bs, d_input)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([7, 3, 32])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6264,  0.3405, -1.9065,  0.1543,  0.0389, -1.6033,  1.4415,  0.4983,\n",
       "          0.7548,  1.0990, -1.1783, -1.3847, -1.7358,  1.4651,  1.0633,  0.2168,\n",
       "         -0.3323,  1.1270,  0.1614,  1.0170,  1.0459, -0.7286,  0.5064, -1.4765,\n",
       "          0.0448, -1.2500,  0.3132, -0.8007,  0.4089,  0.7325, -1.2740,  0.6147],\n",
       "        [ 0.9541,  0.3682, -0.8096,  0.1357, -0.9159, -1.4382, -1.3385,  0.8269,\n",
       "          0.2721, -0.4982,  1.3105, -0.0236, -1.0547, -1.3076,  1.8884, -0.2891,\n",
       "          1.5231,  0.5507, -0.6423,  0.4412,  1.3656,  0.7858, -0.9425, -0.3198,\n",
       "         -0.3162, -0.0086,  1.5257, -1.3216,  1.4492, -0.1750, -0.1669, -1.8291],\n",
       "        [ 1.0132,  0.7205, -0.4221,  0.2952, -1.4117, -0.6182, -1.7520, -1.7426,\n",
       "         -0.4648, -0.2122,  2.0889, -1.3544, -0.1611, -1.0696,  1.3492, -1.0179,\n",
       "          1.2820,  0.8990,  0.7411,  0.8052, -0.5322,  0.6277, -0.2733, -1.0738,\n",
       "         -0.8435,  1.5357,  0.8260, -0.3422, -0.6204,  1.0091,  0.1011,  0.6182]],\n",
       "       grad_fn=<SelectBackward>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x1[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building the decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
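    "To construct the decoder block, the only component we need in addition to the MultiHeadAttention layer is the position-wise feed-forward layer, which applies the same two-layer MLP to every position independently, followed by a residual connection and layer normalization."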
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://i2.wp.com/mlexplained.com/wp-content/uploads/2017/12/%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%BC%E3%83%B3%E3%82%B7%E3%83%A7%E3%83%83%E3%83%88-2017-12-29-19.14.41.png?w=273)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionwiseFF(nn.Module):\n",
    "    def __init__(self, d_input, d_inner, dropout):\n",
    "        super().__init__()\n",
    "\n",
    "        self.d_input = d_input\n",
    "        self.d_inner = d_inner\n",
    "        self.dropout = dropout\n",
    "        self.ff = nn.Sequential(\n",
    "            nn.Linear(d_input, d_inner), nn.ReLU(inplace=True),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(d_inner, d_input),\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "        self.layer_norm = nn.LayerNorm(d_input)\n",
    "\n",
    "    def forward(self, input_: torch.FloatTensor, # (cur_seq, bs, d_input)\n",
    "               ) -> torch.FloatTensor: # (cur_seq, bs, d_input)\n",
    "        ff_out = self.ff(input_)\n",
    "        output = self.layer_norm(input_ + ff_out)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can implement the decoder block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecoderBlock(nn.Module):\n",
    "    def __init__(self, n_heads, d_input, \n",
    "                 d_head_inner, d_ff_inner,\n",
    "                 dropout, dropouta=0.):\n",
    "        super().__init__()\n",
    "        self.mha = MultiHeadAttention(d_input, d_head_inner, n_heads=n_heads, \n",
    "                                      dropout=dropout, dropouta=dropouta)\n",
    "        self.ff = PositionwiseFF(d_input, d_ff_inner, dropout)\n",
    "            \n",
    "    def forward(self, input_: torch.FloatTensor, # (cur_seq, bs, d_input)\n",
    "                pos_embs: torch.FloatTensor, # (cur_seq + prev_seq, d_input)\n",
    "                u: torch.FloatTensor, # (H, d_head_inner)\n",
    "                v: torch.FloatTensor, # (H, d_head_inner)\n",
    "                mask=None,\n",
    "                mems=None,\n",
    "               ):\n",
    "        return self.ff(self.mha(input_, pos_embs, mems, u, v, mask=mask))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The full Transformer XL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now with these components in place, we can build the full Transformer XL model.\n",
    "\n",
    "Aside from what we mentioned above, one common trick in language modeling that we haven't covered yet is tying the input embedding matrix $E$ and the output projection matrix $P$. Remember, a language model predicts the next token in a sequence, so its output is a vector in $\\mathbb{R}^{|V|}$, where $|V|$ is the vocab size. If we constrain the output of the penultimate layer to have the same dimension $d$ as the embeddings, the embedding matrix $E$ will be of shape $\\mathbb{R}^{|V| \\times d}$ and the output projection matrix $P$ will be of shape $\\mathbb{R}^{d \\times |V|}$.\n",
    "\n",
    "In [this paper](https://arxiv.org/abs/1608.05859), the authors found that constraining the matrices such that $ P = E^T $ improved performance while greatly reducing the total parameter count (and thus memory usage!) of the model."
   ]
  },
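  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the parameter savings, here is a small sketch using arbitrary illustrative sizes ($|V| = 1000$, $d = 32$, not this notebook's settings). Because `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ weight.T + bias`, the embedding weight of shape `(|V|, d)` can be assigned to the projection directly, which gives $P = E^T$ with no explicit transpose. This is the same assignment `TransformerXL` makes below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "vocab_size, d = 1000, 32  # illustrative sizes only\n",
    "embedding = nn.Embedding(vocab_size, d)  # E: weight of shape (|V|, d)\n",
    "projection = nn.Linear(d, vocab_size)    # weight (|V|, d), bias (|V|,)\n",
    "# nn.Linear computes x @ weight.T + bias, so sharing E's weight gives P = E^T\n",
    "projection.weight = embedding.weight\n",
    "\n",
    "model = nn.ModuleDict({'embedding': embedding, 'projection': projection})\n",
    "# parameters() yields the shared weight tensor only once\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(n_params)  # |V| * d + |V| = 33000 (65000 without tying)\n",
    "\n",
    "x = torch.randn(4, d)\n",
    "print(torch.allclose(projection(x), x @ embedding.weight.T + projection.bias))"
   ]
  },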
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rather than using the tied weights unchanged, we multiply the embedding outputs by $\\sqrt{d}$, the square root of the embedding dimension. This trick is included in the codebase but not mentioned in the Transformer-XL paper as far as I can tell. The same scaling does appear in the original Transformer paper (Vaswani et al., 2017), which multiplies the embedding weights by $\\sqrt{d_{model}}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "class StandardWordEmbedding(nn.Module):\n",
    "    def __init__(self, num_embeddings, embedding_dim,\n",
    "                div_val=1, sample_softmax=False):\n",
    "        super().__init__()\n",
    "        self.num_embeddings = num_embeddings\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.embedding = nn.Embedding(num_embeddings, embedding_dim)\n",
    "        self.scale = embedding_dim ** 0.5\n",
    "\n",
    "    def forward(self, input_: torch.LongTensor):\n",
    "        return self.embedding(input_) * self.scale"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, all we need to do is to put everything we have implemented above together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerXL(nn.Module):\n",
    "    def __init__(self, num_embeddings, n_layers, n_heads, \n",
    "                 d_model, d_head_inner, d_ff_inner,\n",
    "                 dropout=0.1, dropouta=0., \n",
    "                 seq_len: int=0, mem_len: int=0):\n",
    "        super().__init__()\n",
    "        self.n_layers,self.n_heads,self.d_model,self.d_head_inner,self.d_ff_inner = \\\n",
    "            n_layers,n_heads,d_model,d_head_inner,d_ff_inner\n",
    "        # Embedding layers\n",
    "        self.word_embs = StandardWordEmbedding(num_embeddings, d_model)\n",
    "        self.pos_embs = PositionalEmbedding(d_model)\n",
    "        # Core transformer\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.layers = nn.ModuleList([DecoderBlock(n_heads, d_model, d_head_inner=d_head_inner,\n",
    "                                                  d_ff_inner=d_ff_inner,\n",
    "                                                  dropout=dropout, dropouta=dropouta)\n",
    "                                     for _ in range(n_layers)])\n",
    "\n",
    "        # tie weights\n",
    "        self.output_projection = nn.Linear(d_model, num_embeddings)\n",
    "        self.output_projection.weight = self.word_embs.embedding.weight\n",
    "        self.loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "        self.seq_len, self.mem_len = seq_len, mem_len\n",
    "        \n",
    "        # u and v are global parameters of shape (H, d), shared across all layers:\n",
    "        # maybe making them per-layer parameters might help performance?\n",
    "        # torch.Tensor(...) would leave them uninitialized, so start from zeros\n",
    "        self.u, self.v = (nn.Parameter(torch.zeros(self.n_heads, self.d_head_inner)),\n",
    "                          nn.Parameter(torch.zeros(self.n_heads, self.d_head_inner)))\n",
    "        \n",
    "    def init_memory(self, device=torch.device(\"cpu\")) -> List[torch.FloatTensor]:\n",
    "        return [torch.empty(0, dtype=torch.float).to(device) for _ in range(self.n_layers+1)]\n",
    "    \n",
    "    def update_memory(self, \n",
    "            previous_memory: List[torch.FloatTensor], \n",
    "            hidden_states: List[torch.FloatTensor],\n",
    "        ):\n",
    "        assert len(hidden_states) == len(previous_memory)\n",
    "        mem_len, seq_len = previous_memory[0].size(0), hidden_states[0].size(0)\n",
    "\n",
    "        # For the updated memory, we use the most recent `self.mem_len`\n",
    "        # states, including the previous memory\n",
    "        # In other words, if `seq_len` < `self.mem_len` some of the previous memory\n",
    "        # will carry over to the next memory\n",
    "        with torch.no_grad():\n",
    "            new_memory = []\n",
    "            end_idx = mem_len + seq_len\n",
    "            beg_idx = max(0, end_idx - self.mem_len)\n",
    "            for m, h in zip(previous_memory, hidden_states):\n",
    "                cat = torch.cat([m, h], dim=0) # (mem_len + seq_len, bs, d)\n",
    "                new_memory.append(cat[beg_idx:end_idx].detach()) # (self.mem_len, bs, d)\n",
    "        return new_memory\n",
    "    \n",
    "    def reset_length(self, seq_len, ext_len, mem_len):\n",
    "        self.seq_len = seq_len\n",
    "        self.mem_len = mem_len\n",
    "    \n",
    "    def forward(self, idxs: torch.LongTensor, # (cs, bs)\n",
    "                target: torch.LongTensor, # (cs, bs)\n",
    "                memory: Optional[List[torch.FloatTensor]]=None,\n",
    "               ) -> Dict[str, torch.Tensor]:\n",
    "        if memory is None: \n",
    "            memory: List[torch.FloatTensor] = self.init_memory(idxs.device)\n",
    "        assert len(memory) == len(self.layers) + 1\n",
    "        cur_seq, bs = idxs.size()\n",
    "        prev_seq = memory[0].size(0)\n",
    "        \n",
    "        # Construct attention mask\n",
    "        dec_attn_mask = torch.triu(\n",
    "            torch.ones((cur_seq, cur_seq + prev_seq)),\n",
    "            diagonal=1 + prev_seq,\n",
    "        ).bool()[...,None].to(idxs.device) # masked_fill expects a boolean mask\n",
    "        \n",
    "        word_embs = self.drop(self.word_embs(idxs))\n",
    "        pos_idxs = torch.arange(cur_seq + prev_seq - 1, -1, -1.0, dtype=torch.float).to(word_embs.device)\n",
    "        pos_embs = self.drop(self.pos_embs(pos_idxs))\n",
    "        \n",
    "        # Main part of forward pass\n",
    "        hidden_states = [word_embs]\n",
    "        layer_out = word_embs\n",
    "        for mem, layer in zip(memory, self.layers):\n",
    "            layer_out = layer(layer_out, pos_embs, self.u, self.v, \n",
    "                              mask=dec_attn_mask, mems=mem)\n",
    "            hidden_states.append(layer_out)\n",
    "        \n",
    "        logits = self.output_projection(self.drop(layer_out))        \n",
    "        loss = self.loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))\n",
    "        \n",
    "        # Update memory \n",
    "        # Ensure the memory is treated as a constant\n",
    "        # and we do not back propagate through them\n",
    "        new_memory = self.update_memory(memory, hidden_states)\n",
    "        return {\"loss\": loss, \"logits\": logits, \"memory\": new_memory}"
   ]
  },
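  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two pieces of bookkeeping in `forward` are easy to get wrong, so here is a standalone sketch of each, with arbitrary sizes. The mask uses the same `torch.triu` call as `forward` (as a boolean mask, without the trailing batch axis): with `prev_seq` memory states in front, query `i` may attend to keys `0` through `prev_seq + i`. `update_memory_sketch` is a hypothetical single-layer restatement of `update_memory`; every hidden state is filled with its global time step, so the printout shows which states survive in memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "cur_seq, prev_seq = 3, 2\n",
    "# True marks keys that query i must not attend to (the future)\n",
    "mask = torch.triu(torch.ones(cur_seq, cur_seq + prev_seq),\n",
    "                  diagonal=1 + prev_seq).bool()\n",
    "print(mask.int())\n",
    "\n",
    "def update_memory_sketch(prev_mem, hidden, mem_len):\n",
    "    # hypothetical helper: keep the most recent `mem_len` states\n",
    "    # of [previous memory; new hidden states], without gradients\n",
    "    with torch.no_grad():\n",
    "        cat = torch.cat([prev_mem, hidden], dim=0)\n",
    "        beg_idx = max(0, cat.size(0) - mem_len)\n",
    "        return cat[beg_idx:].detach()\n",
    "\n",
    "bs, d, mem_len, seg_len = 2, 4, 5, 3\n",
    "mem = torch.empty(0, bs, d)\n",
    "for step in range(3):\n",
    "    t = torch.arange(step * seg_len, (step + 1) * seg_len, dtype=torch.float)\n",
    "    seg = t.view(seg_len, 1, 1).expand(seg_len, bs, d)  # state value = time step\n",
    "    mem = update_memory_sketch(mem, seg, mem_len)\n",
    "    print(step, mem[:, 0, 0].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `mem_len = 5` and segments of length 3, the memory after the third segment holds time steps 4 to 8. States 4 and 5 were carried over from the previous memory rather than taken from the latest segment, which is what the comment in `update_memory` means by previous memory carrying over."
   ]
  },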
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer = TransformerXL(1000, 4, 3, 32, 17, 71, mem_len=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, let's feed some random inputs to confirm the model is working."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': tensor(22.7983, grad_fn=<NllLossBackward>),\n",
       " 'logits': tensor([[[  2.3893,   1.9681,  -2.6914,  ...,  -0.8093,  -2.1468,  -1.5656],\n",
       "          [ -6.9673,  -5.4040,   1.5984,  ...,  -8.6291,   9.6236,  -3.4973],\n",
       "          [ -6.9776,  -3.5877,   0.9394,  ..., -12.9488,   4.1804,   4.1679],\n",
       "          ...,\n",
       "          [ -7.8283,   3.3389,  -1.6314,  ...,  -3.7291,   5.9029,  -1.5640],\n",
       "          [  5.7904,   4.9521, -11.5190,  ...,  -1.1907,   5.1361,  -0.6909],\n",
       "          [ -1.2347,   4.8922,  -2.4506,  ...,   1.5876,   5.6535,  -3.9682]],\n",
       " \n",
       "         [[  2.7837,  -1.6279,   5.5712,  ...,   2.1291,   3.5525,   8.2256],\n",
       "          [  4.4832,  -1.4655,  -7.8257,  ...,  -0.7932,  -1.3533,   1.5421],\n",
       "          [ -6.0790,  -7.7122,  -3.2280,  ..., -14.7921,  -1.0191,  19.7041],\n",
       "          ...,\n",
       "          [ -8.7061,  -2.2215, -11.9336,  ...,  -5.7044,   0.8774,  -1.9217],\n",
       "          [ -2.3067,  -3.2391,  -1.1071,  ...,   0.2583,  -5.1745, -15.7446],\n",
       "          [ -0.6714,   8.3893,  -1.0352,  ...,   1.4273,  18.6013,  -3.9901]],\n",
       " \n",
       "         [[  1.1127, -10.2239,   0.9317,  ..., -10.6415,   4.1240,   5.4683],\n",
       "          [  9.5186,  -8.7525,  -0.4702,  ...,  -7.4802,  -1.4103,   8.6529],\n",
       "          [-10.7924,  -2.3530,  -1.1558,  ...,  -9.1787,  -0.4928,  -0.9864],\n",
       "          ...,\n",
       "          [  6.2984,   0.5173,   6.5342,  ...,  -4.1923,  -3.8633,  -0.7991],\n",
       "          [ -9.1815,  -5.3329,  -7.6900,  ...,  -8.8483,   2.2137,   1.6496],\n",
       "          [ -5.3949,   4.4151,  -2.7581,  ...,  -0.5034,   7.8452,  -1.1456]],\n",
       " \n",
       "         [[ -1.8854,  -3.1880,   5.2209,  ...,   4.6219,   5.8592,   6.7511],\n",
       "          [  1.7416,   0.8599,   3.7392,  ...,   2.1722,   2.7656, -12.5893],\n",
       "          [ -4.7122,   5.6147,  -0.2521,  ..., -13.2473,   2.8461,   4.8401],\n",
       "          ...,\n",
       "          [ -9.8932,   4.5135,   9.3876,  ...,   2.2749,   4.2054,   8.6749],\n",
       "          [  2.8098,  -4.5236,   2.9966,  ...,  -9.3179,  -3.0171,   2.8091],\n",
       "          [ -4.3847,  -8.1687, -10.1101,  ...,  -9.0084,   0.1503,   1.9549]],\n",
       " \n",
       "         [[ -3.9730,   5.7617,   0.9482,  ...,   5.6215,  -6.8315,  -3.8019],\n",
       "          [  2.0387,  -1.3776,   4.5145,  ...,  -2.6593,  -7.4511,  -4.3874],\n",
       "          [  8.2944,  -2.8881,  -0.5630,  ...,  -4.1937,  -3.9678,  -5.0382],\n",
       "          ...,\n",
       "          [  9.9117,  -5.5987,   3.6440,  ...,  -3.4378,  -4.1492,   4.0374],\n",
       "          [ -6.7007,   0.7956, -11.9654,  ...,  -3.4386,  -6.1001,   2.1052],\n",
       "          [ -1.1894,  -3.6629,  10.9244,  ...,  -5.9781,   5.6562,   8.7496]]],\n",
       "        grad_fn=<AddBackward0>),\n",
       " 'memory': [tensor([[[ 2.1017e+00,  7.0014e+00, -1.1576e+01,  ..., -4.1735e+00,\n",
       "             7.1704e+00, -9.6242e-01],\n",
       "           [-4.9919e-01,  0.0000e+00, -2.2360e+00,  ..., -7.4936e-01,\n",
       "             5.5746e-01,  4.2537e+00],\n",
       "           [ 7.5275e+00, -0.0000e+00,  6.4217e+00,  ..., -1.0603e+00,\n",
       "             2.6083e-01,  0.0000e+00],\n",
       "           ...,\n",
       "           [-0.0000e+00, -5.6473e+00, -1.2624e+01,  ...,  5.6180e-01,\n",
       "             1.0439e+00,  7.7785e+00],\n",
       "           [ 5.5092e+00,  3.7028e+00, -0.0000e+00,  ...,  1.2419e+01,\n",
       "            -1.2681e+00, -3.8133e+00],\n",
       "           [ 5.9784e+00, -1.2864e+00, -1.6390e+01,  ..., -4.7940e+00,\n",
       "             7.9133e+00, -7.4816e+00]],\n",
       "  \n",
       "          [[-1.2514e+01,  3.6603e+00, -1.1988e+01,  ..., -2.5402e+00,\n",
       "            -6.0357e+00, -1.6756e+01],\n",
       "           [-1.7803e+00,  2.9728e+00,  1.4087e+00,  ...,  4.8081e+00,\n",
       "             1.9607e+00,  5.2030e+00],\n",
       "           [ 2.3099e+00, -5.5314e+00,  7.6574e+00,  ...,  7.1721e+00,\n",
       "            -4.0299e-01, -0.0000e+00],\n",
       "           ...,\n",
       "           [-6.4647e+00,  9.5205e+00, -9.9825e+00,  ...,  5.7900e+00,\n",
       "             1.3428e+01,  7.6066e+00],\n",
       "           [-3.3331e+00,  5.6945e+00, -3.7368e+00,  ...,  0.0000e+00,\n",
       "            -2.3161e+00, -6.7402e+00],\n",
       "           [-5.6126e+00,  6.1271e+00,  1.1794e+00,  ..., -1.2898e-01,\n",
       "             1.8484e+00, -2.3268e-01]],\n",
       "  \n",
       "          [[-1.0437e+01, -0.0000e+00, -5.1754e-01,  ..., -1.7993e+00,\n",
       "             1.6501e+00,  5.9062e+00],\n",
       "           [ 1.0803e+01, -9.8891e+00, -1.6223e+00,  ..., -3.4334e+00,\n",
       "             3.3257e+00, -1.2487e+00],\n",
       "           [-2.0990e-01,  5.3983e+00, -0.0000e+00,  ...,  7.3918e+00,\n",
       "            -3.5076e+00,  1.8095e+00],\n",
       "           ...,\n",
       "           [ 3.4760e+00, -2.3750e+00,  1.5633e+00,  ..., -1.1462e+01,\n",
       "            -0.0000e+00, -3.0820e+00],\n",
       "           [ 3.1093e+00, -0.0000e+00,  3.0700e+00,  ..., -5.9241e+00,\n",
       "             0.0000e+00, -2.5284e+00],\n",
       "           [-7.4560e+00,  1.1351e+00,  6.4418e+00,  ...,  9.4216e+00,\n",
       "             1.3614e+00,  1.3476e+00]],\n",
       "  \n",
       "          [[ 0.0000e+00, -0.0000e+00, -2.3496e+00,  ...,  4.0335e+00,\n",
       "             1.3728e+01,  4.7616e-01],\n",
       "           [-9.9943e-01,  2.9256e+00, -2.8993e+00,  ..., -0.0000e+00,\n",
       "             6.8954e-01,  6.7603e+00],\n",
       "           [-6.9932e+00, -5.8414e-01,  1.7446e+00,  ..., -2.0840e+00,\n",
       "             3.7240e+00, -3.9543e+00],\n",
       "           ...,\n",
       "           [-1.5157e+00,  0.0000e+00,  5.2591e+00,  ...,  3.7703e-01,\n",
       "            -0.0000e+00, -1.3468e+00],\n",
       "           [-1.0982e+00, -0.0000e+00, -4.7111e+00,  ...,  4.7480e+00,\n",
       "             4.9005e+00, -2.4251e+00],\n",
       "           [-2.3143e+00,  0.0000e+00, -1.9084e+00,  ...,  1.6107e+00,\n",
       "             2.3575e+00, -5.4302e+00]],\n",
       "  \n",
       "          [[ 0.0000e+00,  1.3610e+01, -4.1463e+00,  ..., -4.5901e+00,\n",
       "             2.3292e+00, -1.0738e+01],\n",
       "           [ 4.0665e+00,  0.0000e+00,  7.8862e+00,  ...,  0.0000e+00,\n",
       "            -9.0152e+00, -2.4664e+00],\n",
       "           [ 1.5953e+01, -1.2290e+00,  0.0000e+00,  ..., -0.0000e+00,\n",
       "             6.9739e+00,  0.0000e+00],\n",
       "           ...,\n",
       "           [-1.5245e+00,  1.4178e-02,  3.7615e-01,  ..., -1.5414e+00,\n",
       "             5.3257e+00,  6.1866e+00],\n",
       "           [-7.2876e+00,  1.2365e+01, -2.8329e+00,  ...,  5.8748e+00,\n",
       "            -6.5736e+00, -1.1682e+01],\n",
       "           [-0.0000e+00, -4.1767e-01, -2.7669e-01,  ...,  5.1264e+00,\n",
       "             2.3285e+00,  5.2544e+00]]]),\n",
       "  tensor([[[ 1.3826,  0.9239, -2.2534,  ...,  0.5224,  1.6166,  0.1766],\n",
       "           [-0.0735,  0.2080, -0.4732,  ..., -0.4010, -0.2503,  0.4283],\n",
       "           [ 1.8843, -0.6893,  0.9431,  ..., -0.2929, -0.1904,  0.0774],\n",
       "           ...,\n",
       "           [-0.5283, -1.1426, -1.9640,  ...,  0.1228,  0.3225,  0.7912],\n",
       "           [ 2.0171,  0.6775,  0.0373,  ...,  0.5712, -0.4692, -1.0398],\n",
       "           [ 1.3282, -0.1157, -2.2052,  ..., -0.1816,  1.3291, -1.3611]],\n",
       "  \n",
       "          [[-1.0597,  0.6241, -1.3270,  ..., -0.0906, -0.0414, -1.6208],\n",
       "           [-0.4735,  0.9648, -0.1781,  ...,  1.0080, -0.0339,  0.7943],\n",
       "           [ 1.3104, -1.8248,  0.5972,  ...,  0.5662, -0.6202,  0.1132],\n",
       "           ...,\n",
       "           [-0.2953,  0.9277, -1.1543,  ...,  0.7062,  2.1062,  0.5092],\n",
       "           [-0.3843,  0.3064, -0.3782,  ..., -0.0153, -1.1835, -1.1707],\n",
       "           [-1.0753,  0.9251, -0.4859,  ..., -0.0698,  0.5005, -1.1221]],\n",
       "  \n",
       "          [[-1.0846, -0.6314, -0.0176,  ..., -0.6004,  0.0514,  2.0411],\n",
       "           [ 1.7022, -1.3314, -0.0765,  ..., -1.7647,  0.3343, -0.8403],\n",
       "           [ 0.9930,  0.6483, -0.3341,  ...,  1.0334, -1.5398,  0.7327],\n",
       "           ...,\n",
       "           [ 1.3533, -1.2107,  0.4709,  ..., -1.0716, -0.0823, -0.8666],\n",
       "           [ 1.6170,  0.4248,  0.3163,  ..., -1.3363, -0.8826, -0.7667],\n",
       "           [ 0.0648,  0.1811,  1.0545,  ...,  1.7590,  0.5024,  0.3842]],\n",
       "  \n",
       "          [[ 0.0324, -1.0944, -1.1538,  ...,  0.4320,  2.1470,  0.4528],\n",
       "           [-0.0979,  0.6155, -0.7837,  ..., -0.3645, -0.2686,  1.5890],\n",
       "           [-0.3416,  0.3403,  0.1237,  ..., -0.8141,  0.5855, -0.7689],\n",
       "           ...,\n",
       "           [-0.0435, -1.0859,  1.3180,  ...,  0.3539,  0.3149, -0.3087],\n",
       "           [-0.1995, -0.3531, -0.5246,  ...,  0.6699,  0.5467,  0.2343],\n",
       "           [ 0.7265,  0.1116, -0.4563,  ...,  0.6268,  0.6353, -1.0494]],\n",
       "  \n",
       "          [[ 0.1539,  2.0148, -0.7285,  ..., -0.6078,  0.3653, -1.3313],\n",
       "           [ 1.5128,  0.1665,  1.2747,  ...,  0.1989, -0.8545, -0.3040],\n",
       "           [ 3.4671, -0.3973, -0.4849,  ...,  0.2124,  0.7675, -0.5249],\n",
       "           ...,\n",
       "           [ 0.1298, -0.5067,  0.1377,  ...,  0.6606,  0.6893,  1.1767],\n",
       "           [-0.5558,  1.9915, -0.5002,  ...,  0.9498, -1.0524, -1.6267],\n",
       "           [ 0.8146, -1.0068, -0.6703,  ...,  0.8581,  0.5871,  1.1277]]]),\n",
       "  tensor([[[ 1.6346e+00,  9.5465e-01, -2.0691e+00,  ...,  4.4820e-01,\n",
       "             2.2910e+00,  4.9902e-01],\n",
       "           [ 2.4900e-01,  1.8973e-01, -3.3778e-01,  ..., -8.5209e-01,\n",
       "            -5.5017e-01,  5.1561e-01],\n",
       "           [ 1.4326e+00, -8.8077e-01,  1.0193e+00,  ..., -3.8159e-01,\n",
       "             8.8002e-03,  5.8506e-01],\n",
       "           ...,\n",
       "           [-6.0461e-01, -7.7426e-01, -1.6539e+00,  ...,  1.4462e-01,\n",
       "            -1.6225e-01,  5.6319e-01],\n",
       "           [ 1.5300e+00,  7.9754e-01, -1.2463e-01,  ...,  1.5487e-01,\n",
       "             4.7096e-02, -2.2055e-01],\n",
       "           [ 1.3635e+00,  2.2628e-01, -2.5490e+00,  ..., -2.4268e-01,\n",
       "             1.9050e+00, -3.0230e-01]],\n",
       "  \n",
       "          [[-7.8208e-01,  6.8655e-01, -1.4995e+00,  ..., -1.5197e-01,\n",
       "             7.7530e-01, -1.1415e+00],\n",
       "           [ 2.8964e-01,  1.5075e+00,  2.6526e-02,  ...,  9.5593e-01,\n",
       "             1.4820e-01,  5.8030e-01],\n",
       "           [ 1.0642e+00, -2.0875e+00,  7.8689e-01,  ...,  9.4528e-01,\n",
       "            -8.7262e-01,  2.6703e-01],\n",
       "           ...,\n",
       "           [ 9.6837e-02,  1.2272e+00, -9.9427e-01,  ...,  5.3825e-01,\n",
       "             2.2138e+00,  8.4783e-01],\n",
       "           [-4.6253e-01,  7.3703e-01, -4.7150e-01,  ...,  2.4403e-01,\n",
       "            -1.1370e+00, -5.6916e-01],\n",
       "           [-1.0717e+00,  1.2868e+00, -3.9602e-01,  ...,  4.7487e-01,\n",
       "             7.4984e-01, -1.3079e-01]],\n",
       "  \n",
       "          [[-9.9346e-01, -4.6816e-01, -4.2036e-01,  ...,  1.9013e-01,\n",
       "             5.6324e-01,  1.7506e+00],\n",
       "           [ 1.4775e+00, -7.2553e-01, -3.9095e-01,  ..., -1.6459e+00,\n",
       "            -5.1008e-01, -1.1018e+00],\n",
       "           [ 1.0286e+00,  6.4774e-02, -1.1253e-02,  ...,  8.0176e-01,\n",
       "            -1.8767e+00,  9.6665e-01],\n",
       "           ...,\n",
       "           [ 6.5649e-01, -9.1886e-01,  4.0203e-01,  ..., -8.7860e-01,\n",
       "             2.5020e-01, -2.3505e-02],\n",
       "           [ 1.5418e+00,  4.3293e-01,  6.1181e-01,  ..., -1.0745e+00,\n",
       "            -6.0872e-01, -2.7264e-01],\n",
       "           [ 1.3029e-01,  5.1017e-01,  8.4829e-01,  ...,  1.6529e+00,\n",
       "             5.5392e-01,  9.1532e-01]],\n",
       "  \n",
       "          [[ 3.3911e-02, -8.4844e-01, -1.0074e+00,  ...,  4.9299e-01,\n",
       "             2.4838e+00,  3.1425e-01],\n",
       "           [ 6.1697e-01,  7.7079e-01, -3.5887e-01,  ..., -2.5764e-01,\n",
       "            -9.5384e-01,  1.4135e+00],\n",
       "           [-3.3163e-01,  3.4299e-01,  5.2204e-01,  ..., -7.2144e-01,\n",
       "             6.0687e-01,  8.2380e-02],\n",
       "           ...,\n",
       "           [-1.9737e-01, -9.4352e-01,  1.7135e+00,  ...,  3.5844e-02,\n",
       "             7.1961e-01,  9.6403e-02],\n",
       "           [-1.1394e-01, -1.4094e-02, -1.6593e-01,  ..., -1.3911e-01,\n",
       "             1.5920e-01,  4.5853e-01],\n",
       "           [ 7.0644e-01,  6.4221e-01, -9.2809e-01,  ...,  1.2030e+00,\n",
       "             5.6345e-01, -1.0111e+00]],\n",
       "  \n",
       "          [[ 8.9835e-01,  1.9890e+00, -8.9823e-01,  ..., -8.1031e-01,\n",
       "             5.3538e-01, -1.2763e+00],\n",
       "           [ 1.7208e+00,  2.1950e-01,  1.2462e+00,  ..., -2.1622e-01,\n",
       "            -7.6752e-01,  1.7855e-03],\n",
       "           [ 2.8792e+00, -3.1874e-01, -2.0828e-01,  ...,  1.0098e-01,\n",
       "             7.4217e-01, -2.6374e-02],\n",
       "           ...,\n",
       "           [ 1.7004e-02, -6.7681e-01,  2.5581e-01,  ...,  3.3252e-01,\n",
       "             8.9061e-01,  1.4128e+00],\n",
       "           [ 1.5477e-02,  2.3536e+00,  7.8383e-02,  ...,  3.7001e-01,\n",
       "            -8.4431e-01, -9.0505e-01],\n",
       "           [ 5.4680e-01, -7.6991e-01, -4.1328e-01,  ...,  1.0264e+00,\n",
       "             9.9558e-01,  1.6870e+00]]]),\n",
       "  tensor([[[ 1.3837e+00,  1.0989e+00, -2.2840e+00,  ...,  6.0480e-01,\n",
       "             1.8528e+00,  1.2792e-01],\n",
       "           [-2.2875e-01,  9.7942e-02, -7.5347e-01,  ..., -5.9109e-01,\n",
       "            -6.4622e-01,  4.3441e-01],\n",
       "           [ 1.3624e+00, -5.5102e-01,  1.2233e+00,  ..., -4.0880e-01,\n",
       "            -1.5982e-01,  5.9552e-01],\n",
       "           ...,\n",
       "           [-9.3327e-01, -1.0077e+00, -2.1290e+00,  ...,  3.9902e-01,\n",
       "             5.4323e-01,  1.4707e-01],\n",
       "           [ 9.5668e-01,  1.4832e+00, -7.4227e-01,  ...,  4.8545e-01,\n",
       "             5.3537e-02, -9.6289e-01],\n",
       "           [ 1.2578e+00,  6.9142e-02, -2.4296e+00,  ...,  9.2327e-02,\n",
       "             1.6102e+00, -1.2208e+00]],\n",
       "  \n",
       "          [[-8.3147e-01,  1.2378e+00, -1.5615e+00,  ...,  3.6822e-01,\n",
       "             5.2679e-01, -1.5067e+00],\n",
       "           [ 2.0892e-01,  1.4041e+00, -2.0698e-01,  ...,  1.2122e+00,\n",
       "             1.3742e-01,  6.8231e-01],\n",
       "           [ 1.4700e+00, -2.0746e+00,  8.6280e-01,  ...,  7.7288e-01,\n",
       "            -1.0128e+00,  1.9533e-01],\n",
       "           ...,\n",
       "           [-4.0518e-01,  1.0060e+00, -1.2097e+00,  ...,  6.6154e-01,\n",
       "             2.2548e+00,  8.4912e-01],\n",
       "           [-9.0720e-01,  6.4946e-01, -7.6629e-01,  ...,  8.3773e-02,\n",
       "            -2.2143e-01, -5.6469e-01],\n",
       "           [-1.5342e+00,  1.2075e+00, -4.9365e-01,  ...,  7.7658e-01,\n",
       "             8.2622e-01, -5.1317e-01]],\n",
       "  \n",
       "          [[-1.0310e+00, -6.3179e-01, -8.8332e-01,  ...,  8.8619e-02,\n",
       "             6.0848e-01,  1.5056e+00],\n",
       "           [ 1.5093e+00, -1.1469e+00, -4.5369e-01,  ..., -9.8717e-01,\n",
       "            -6.3974e-01, -1.0479e+00],\n",
       "           [ 1.3946e+00,  1.5321e-01,  4.9116e-01,  ...,  1.0110e+00,\n",
       "            -2.0338e+00,  6.2318e-01],\n",
       "           ...,\n",
       "           [ 2.7276e-02, -1.0978e+00, -1.2747e-03,  ..., -6.6941e-02,\n",
       "             3.4110e-01, -3.0183e-01],\n",
       "           [ 1.0870e+00,  3.6189e-01,  3.5213e-01,  ..., -7.1679e-01,\n",
       "            -1.2220e-01, -4.6191e-01],\n",
       "           [ 1.5815e-01,  5.9903e-01,  4.3838e-01,  ...,  1.8811e+00,\n",
       "             6.4441e-01,  5.0634e-01]],\n",
       "  \n",
       "          [[ 1.4628e-02, -6.5199e-01, -1.3255e+00,  ...,  1.1559e+00,\n",
       "             1.8737e+00, -1.2272e-01],\n",
       "           [ 5.7514e-01,  4.6596e-01, -3.4672e-01,  ...,  3.9621e-01,\n",
       "            -1.3981e+00,  1.2205e+00],\n",
       "           [-6.2185e-01,  8.9906e-01,  6.5405e-01,  ..., -8.0875e-01,\n",
       "            -3.7445e-02,  9.1749e-02],\n",
       "           ...,\n",
       "           [-5.9502e-01, -6.5149e-01,  1.2600e+00,  ...,  7.0306e-01,\n",
       "             4.3226e-01, -3.6069e-03],\n",
       "           [-2.5798e-01, -3.2951e-01, -6.0422e-01,  ...,  7.6340e-02,\n",
       "             1.9422e-01,  5.2702e-02],\n",
       "           [ 2.5915e-01,  4.1826e-01, -8.1916e-01,  ...,  1.3079e+00,\n",
       "             5.0282e-01, -8.8068e-01]],\n",
       "  \n",
       "          [[ 8.9606e-01,  2.7360e+00, -7.7418e-01,  ..., -6.3446e-01,\n",
       "             4.9931e-01, -1.4507e+00],\n",
       "           [ 1.6607e+00,  3.6149e-01,  7.9179e-01,  ...,  5.6409e-02,\n",
       "            -1.4083e+00,  3.5620e-02],\n",
       "           [ 2.6172e+00,  4.9535e-02, -1.5108e-01,  ...,  7.3331e-02,\n",
       "             1.2261e+00, -1.4635e-01],\n",
       "           ...,\n",
       "           [-3.5466e-01, -6.0469e-01, -1.8319e-02,  ...,  9.4671e-01,\n",
       "             7.6173e-01,  1.3294e+00],\n",
       "           [-6.0030e-01,  2.4628e+00, -2.9686e-02,  ...,  2.3356e-01,\n",
       "            -6.3912e-01, -9.3690e-01],\n",
       "           [ 5.5543e-01, -3.8047e-01, -3.4127e-01,  ...,  1.9001e+00,\n",
       "             6.9048e-01,  1.9095e+00]]]),\n",
       "  tensor([[[ 1.9261,  0.3662, -2.2351,  ..., -0.2479,  2.0208, -0.3038],\n",
       "           [ 0.2193,  0.2397, -0.6246,  ..., -0.1977, -0.7800,  0.0299],\n",
       "           [ 1.3310, -0.8811,  0.9491,  ...,  0.1043, -0.0846,  0.5337],\n",
       "           ...,\n",
       "           [-0.3900, -1.2184, -2.0538,  ...,  0.3454,  0.7886, -0.4477],\n",
       "           [ 1.7995,  1.5179, -1.2692,  ...,  0.6605,  0.3598, -1.2437],\n",
       "           [ 1.9007,  0.0435, -2.2540,  ...,  0.2233,  1.4315, -0.8018]],\n",
       "  \n",
       "          [[-0.2991,  0.3210, -2.0743,  ..., -0.1376,  0.6999, -1.8585],\n",
       "           [ 0.2477,  1.0192, -0.5260,  ...,  1.0971, -0.3136,  0.0693],\n",
       "           [ 1.9090, -1.7206,  0.3304,  ...,  0.8011, -1.1116,  0.2133],\n",
       "           ...,\n",
       "           [ 0.3987,  0.1333, -0.8556,  ...,  0.7316,  2.1006,  0.4276],\n",
       "           [ 0.1955,  0.6050, -1.0433,  ...,  0.5461,  0.0905, -1.1321],\n",
       "           [-1.1648,  0.8397, -0.7285,  ...,  0.9650,  1.0187, -0.9154]],\n",
       "  \n",
       "          [[-0.2072, -0.9251, -0.4631,  ...,  0.0555,  0.9139,  0.6636],\n",
       "           [ 1.8273, -0.8446, -0.6740,  ..., -0.9765, -0.6065, -1.1694],\n",
       "           [ 1.1242,  0.0836,  0.1622,  ...,  1.1291, -2.1867,  0.2463],\n",
       "           ...,\n",
       "           [ 0.2276, -1.3735, -0.2455,  ..., -0.4645, -0.0795, -0.4198],\n",
       "           [ 1.3959,  0.1959, -0.1209,  ..., -0.4528, -0.3387, -0.9129],\n",
       "           [ 1.1302,  0.4864,  0.4106,  ...,  1.6754,  0.6609,  0.2517]],\n",
       "  \n",
       "          [[ 0.7390, -1.2320, -1.4019,  ...,  0.7018,  1.4972, -0.7071],\n",
       "           [ 0.4820,  0.7413, -0.4056,  ...,  0.6629, -1.9494,  1.4256],\n",
       "           [-0.0761, -0.0412,  0.4128,  ..., -0.6761, -0.7692, -0.3830],\n",
       "           ...,\n",
       "           [-0.1002, -1.1125,  0.8038,  ...,  0.7448, -0.3724, -0.9875],\n",
       "           [ 0.6388, -1.2257, -0.8242,  ...,  0.2262, -0.1097, -0.1556],\n",
       "           [ 1.1622,  0.5248, -0.8697,  ...,  1.2056,  0.7714, -1.3186]],\n",
       "  \n",
       "          [[ 1.1718,  2.1725, -1.0235,  ..., -0.7244,  0.2784, -2.0104],\n",
       "           [ 1.6596,  0.0280,  0.4212,  ...,  0.4771, -1.8212, -0.2746],\n",
       "           [ 2.8073, -0.2165, -0.2219,  ...,  0.3132,  0.9727, -0.2467],\n",
       "           ...,\n",
       "           [ 0.3232, -0.8092, -0.2024,  ...,  0.5945, -0.4198,  0.5399],\n",
       "           [-0.2653,  2.1102, -0.0695,  ...,  0.6733, -0.5012, -1.4819],\n",
       "           [ 0.7600, -0.0512, -0.2453,  ...,  1.5513,  0.3439,  1.1348]]])]}"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idxs = torch.randint(1000, (5, 9))\n",
    "tgts = torch.randint(1000, (5, 9))\n",
    "transformer(idxs, tgts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training the Transformer XL"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's move on to actually training the Transformer XL."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "TESTING = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
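    "We'll use the following configuration. Setting `TESTING = True` above selects a small model and a short run for debugging; otherwise the full-size settings are used."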
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Config(dict):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        for k, v in kwargs.items():\n",
    "            setattr(self, k, v)\n",
    "    \n",
    "    def set(self, key, val):\n",
    "        self[key] = val\n",
    "        setattr(self, key, val)\n",
    "        \n",
    "    def update(self, dct):\n",
    "        for k, v in dct.items():\n",
    "            self.set(k, v)\n",
    "\n",
    "# In testing mode we use small, mostly prime sizes (3, 17, 41, 71, ...)\n",
    "# so that mixing up two dimensions raises a shape error instead of silently working\n",
    "config = Config(\n",
    "    seed=101,\n",
    "    debug=False,\n",
    "    warmup_step=0,\n",
    "    # Check default params\n",
    "    min_lr=0., \n",
    "    dropouta=0.,\n",
    "    clip=0.25,\n",
    "    log_interval=200,\n",
    "    eval_interval=100,\n",
    ")\n",
    "\n",
    "if TESTING:\n",
    "    config.update(dict(\n",
    "        debug=True,\n",
    "        lr=0.00025,\n",
    "        bs=8,\n",
    "        epochs=2,\n",
    "        max_step=10000, # shorten for testing\n",
    "        n_layers=4,\n",
    "        n_heads=3,\n",
    "        d_model=32,\n",
    "        d_head_inner=17,\n",
    "        d_ff_inner=71,\n",
    "        dropout=0.1,\n",
    "        train_bptt=33,\n",
    "        eval_bptt=41,\n",
    "        mem_len=41,\n",
    "        eval_mem_len=63,\n",
    "    ))\n",
    "else:\n",
    "    config.update(dict(\n",
    "        lr=0.0025,\n",
    "        bs=22,\n",
    "        epochs=2,\n",
    "        max_step=400000,\n",
    "        n_layers=12,\n",
    "        n_heads=8,\n",
    "        d_model=512,\n",
    "        d_head_inner=64,\n",
    "        d_ff_inner=2048,\n",
    "        dropout=0.1,\n",
    "        train_bptt=512,\n",
    "        eval_bptt=128,\n",
    "        mem_len=512,\n",
    "        eval_mem_len=2100,\n",
    "    ))"
   ]
  },
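  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why bother with odd sizes like 17 and 71 in testing mode? When two dimensions happen to be equal, a shape bug such as a forgotten transpose can run without complaint and silently compute the wrong thing; with distinct sizes it fails immediately. The toy functions `scores` and `buggy_scores` below are hypothetical illustrations, not part of the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def scores(q, k):\n",
    "    # Correct: (seq_q, d) @ (d, seq_k) -> (seq_q, seq_k)\n",
    "    return q @ k.t()\n",
    "\n",
    "def buggy_scores(q, k):\n",
    "    # Bug: forgot to transpose k\n",
    "    return q @ k\n",
    "\n",
    "# With equal sizes (seq == d == 4) the bug goes unnoticed\n",
    "q, k = torch.randn(4, 4), torch.randn(4, 4)\n",
    "print(buggy_scores(q, k).shape)\n",
    "\n",
    "# With distinct sizes (seq=5, d=3) the same bug raises a shape error\n",
    "q, k = torch.randn(5, 3), torch.randn(5, 3)\n",
    "print(scores(q, k).shape)\n",
    "try:\n",
    "    buggy_scores(q, k)\n",
    "    caught = False\n",
    "except RuntimeError:\n",
    "    caught = True\n",
    "print('shape error caught:', caught)"
   ]
  },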
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing the Data Loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data loading for the Transformer XL is similar to data loading for an RNN-based language model but is quite different from standard data loading, so we'll go over it in detail.\n",
    "\n",
    "Suppose we chunk the input into sequences of 4 words to feed into the model. Remember that the Transformer XL is stateful, meaning the hidden states computed for each minibatch are carried over to the next minibatch as memory. For a minibatch size of 1, handling this is simple: we just chunk the input and feed the chunks into the model in order, like this:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://i0.wp.com/mlexplained.com/wp-content/uploads/2019/06/Screen-Shot-2019-07-03-at-8.53.22-PM.png?w=1554&ssl=1)"
   ]
  },
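  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make *stateful* concrete, here is a toy stand-in for the model (a running sum, not the Transformer XL itself) whose only state is a single number carried from one segment to the next. Feeding 4-token segments in order, with the state passed along, gives exactly the same result as processing the whole sequence at once. This is the property the data loader has to preserve when we move to larger batch sizes. The names `running_sum` and `seg_len` are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def running_sum(segment, state):\n",
    "    # Toy stateful model: each output depends on every token seen so far,\n",
    "    # summarized by `state` (a stand-in for the Transformer XL memory)\n",
    "    out = segment.cumsum(0) + state\n",
    "    return out, out[-1]\n",
    "\n",
    "tokens = torch.arange(1, 13)  # a 12-token corpus: 1, 2, ..., 12\n",
    "seg_len = 4\n",
    "\n",
    "state = torch.tensor(0)\n",
    "outputs = []\n",
    "for start in range(0, len(tokens), seg_len):\n",
    "    out, state = running_sum(tokens[start:start + seg_len], state)\n",
    "    outputs.append(out)\n",
    "\n",
    "chunked = torch.cat(outputs)\n",
    "# True: carrying the state across segments reproduces the full-sequence result\n",
    "print(torch.equal(chunked, tokens.cumsum(0)))"
   ]
  },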
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, what happens if the `batch_size` is 2? We can't simply split the sentence like this; otherwise, we would break the dependencies between segments."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://mlexplained.com/wp-content/uploads/2019/07/Screen-Shot-2019-07-03-at-8.56.15-PM-e1562212986605.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The correct way to handle the corpus with a `batch_size` of 2 is to feed it like this:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://mlexplained.com/wp-content/uploads/2019/07/Screen-Shot-2019-07-03-at-9.05.04-PM-1-e1562213341253.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generalizing this, we first divide the corpus into `batch_size` contiguous segments of equal length, then feed each segment piece by piece into the model, one segment per batch position. Let's go through an example. Suppose `batch_size` is 4 and our entire corpus looks like this:\n",
    "\n",
    "    pytorch is an amazing deep learning framework that makes nlp really easy\n",
    "\n",
    "We want each position in a batch to continue exactly where the same position in the previous batch left off.\n",
    "\n",
    "In other words, assuming we fed the model one word at a time, we would iterate over this sentence like this:\n",
    "\n",
    "    Batch 1: pytorch   amazing   framework nlp\n",
    "    Batch 2: is        deep      that      really\n",
    "    Batch 3: an        learning  makes     easy\n",
    "\n",
    "Notice that you can reconstruct the original sentence by reading each column from top to bottom, moving from left to right,\n",
    "instead of reading row by row.\n",
    "\n",
    "In reality, we feed the model a sequence of words for each batch. The length of this sequence is commonly referred to as the bptt (backpropagation through time) length, since it is the maximum distance the gradients propagate through in the sequence direction. With a bptt length of 2, for example, the\n",
    "minibatch would be of shape (bptt, batch_size) and would look like this:\n",
    "\n",
    "    Batch 1: pytorch   amazing   framework nlp\n",
    "             is        deep      that      really\n",
    "    Batch 2: an        learning  makes     easy"
   ]
  },
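  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reshaping described above can be reproduced on the example sentence. This sketch numericalizes the words with a throwaway dictionary (`word_to_id` is hypothetical, not the vocabulary used later), splits them into `batch_size` contiguous segments with `view`, and transposes so that each row is one time step and each column one segment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "words = 'pytorch is an amazing deep learning framework that makes nlp really easy'.split()\n",
    "word_to_id = {w: i for i, w in enumerate(words)}  # every word is unique here\n",
    "ids = torch.tensor([word_to_id[w] for w in words])\n",
    "\n",
    "batch_size = 4\n",
    "n_steps = ids.size(0) // batch_size  # 12 // 4 = 3 tokens per segment\n",
    "# (batch_size, n_steps) -> transpose -> (n_steps, batch_size)\n",
    "grid = ids[:n_steps * batch_size].view(batch_size, n_steps).t()\n",
    "\n",
    "for step, row in enumerate(grid, start=1):\n",
    "    print(f'Batch {step}:', ' '.join(words[i] for i in row.tolist()))"
   ]
  },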
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can implement this in a dataloader like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils import data\n",
    "import math\n",
    "\n",
    "class LMDataLoader(data.DataLoader):\n",
    "    def __init__(self, data: torch.LongTensor, batch_size: int, bptt: int,\n",
    "                 device=torch.device(\"cpu\")):\n",
    "        self.batch_size = batch_size\n",
    "        self.bptt = bptt\n",
    "        self.n_steps = data.size(0) // batch_size\n",
    "        \n",
    "        # we reshape the data here so that we can index\n",
    "        # efficiently into it while training\n",
    "        self.data = (data[:self.n_steps * batch_size] # trim off any elements that don't fit cleanly\n",
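    "                     .view(batch_size, self.n_steps) # one contiguous segment per row\n",
    "                     .transpose(0, 1) # -> (n_steps, batch_size): time steps along dim 0, segments along dim 1\n",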
    "                     .contiguous().to(device) # put on device as contiguous tensor\n",
    "                     )\n",
    "    \n",
    "    def __iter__(self):\n",
    "        for batch_start_idx in range(0, self.data.size(0) - 1, self.bptt):\n",
    "            batch_end_idx = min(batch_start_idx + self.bptt, self.data.size(0) - 1)\n",
    "            # TODO: What is `self.ext_len` in the original code?\n",
    "            batch_data = self.data[batch_start_idx:batch_end_idx]\n",
    "            target = self.data[batch_start_idx+1:batch_end_idx+1]\n",
    "            # we generate the sequence length as well for loss calculation later\n",
    "            yield batch_data, target, batch_end_idx - batch_start_idx\n",
    "    \n",
    "    def __len__(self):\n",
    "        # must match __iter__: the last row is only ever used as a target\n",
    "        return math.ceil((self.data.size(0) - 1) / self.bptt)"
   ]
  },
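  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Two properties are worth checking on any loader like this: the targets are the inputs shifted by one time step, and each batch picks up exactly where the previous one stopped in every column. The standalone generator `bptt_batches` below is a hypothetical re-implementation of the slicing in `LMDataLoader.__iter__` (without device handling), written so these checks can run in isolation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "\n",
    "def bptt_batches(corpus, batch_size, bptt):\n",
    "    # Same reshape as LMDataLoader: (n_steps, batch_size), one segment per column\n",
    "    n_steps = corpus.size(0) // batch_size\n",
    "    grid = corpus[:n_steps * batch_size].view(batch_size, n_steps).t().contiguous()\n",
    "    for start in range(0, grid.size(0) - 1, bptt):\n",
    "        end = min(start + bptt, grid.size(0) - 1)\n",
    "        yield grid[start:end], grid[start + 1:end + 1], end - start\n",
    "\n",
    "demo_corpus = torch.arange(100)\n",
    "batches = list(bptt_batches(demo_corpus, batch_size=4, bptt=7))\n",
    "\n",
    "for (x, y, _), (x_next, _, _) in zip(batches, batches[1:]):\n",
    "    assert torch.equal(x[1:], y[:-1])     # targets are inputs shifted by one step\n",
    "    assert torch.equal(x_next[0], y[-1])  # the next batch continues each column\n",
    "\n",
    "# The number of batches matches ceil((n_steps - 1) / bptt)\n",
    "n_steps = demo_corpus.size(0) // 4\n",
    "print(len(batches), math.ceil((n_steps - 1) / 7))"
   ]
  },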
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test this out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_corpus = torch.arange(1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "BS = 16\n",
    "BPTT = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_corpus[:BPTT]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = LMDataLoader(test_corpus, BS, BPTT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "b1, *_ = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10, 16])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  0,  62, 124, 186, 248, 310, 372, 434, 496, 558, 620, 682, 744, 806,\n",
       "         868, 930],\n",
       "        [  1,  63, 125, 187, 249, 311, 373, 435, 497, 559, 621, 683, 745, 807,\n",
       "         869, 931],\n",
       "        [  2,  64, 126, 188, 250, 312, 374, 436, 498, 560, 622, 684, 746, 808,\n",
       "         870, 932],\n",
       "        [  3,  65, 127, 189, 251, 313, 375, 437, 499, 561, 623, 685, 747, 809,\n",
       "         871, 933],\n",
       "        [  4,  66, 128, 190, 252, 314, 376, 438, 500, 562, 624, 686, 748, 810,\n",
       "         872, 934],\n",
       "        [  5,  67, 129, 191, 253, 315, 377, 439, 501, 563, 625, 687, 749, 811,\n",
       "         873, 935],\n",
       "        [  6,  68, 130, 192, 254, 316, 378, 440, 502, 564, 626, 688, 750, 812,\n",
       "         874, 936],\n",
       "        [  7,  69, 131, 193, 255, 317, 379, 441, 503, 565, 627, 689, 751, 813,\n",
       "         875, 937],\n",
       "        [  8,  70, 132, 194, 256, 318, 380, 442, 504, 566, 628, 690, 752, 814,\n",
       "         876, 938],\n",
       "        [  9,  71, 133, 195, 257, 319, 381, 443, 505, 567, 629, 691, 753, 815,\n",
       "         877, 939]])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b1"
   ]
  },
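  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first row of `b1` reads 0, 62, 124, ...: with 1000 tokens and `BS = 16`, each column holds a segment of `1000 // 16 = 62` tokens, so column `j` starts at token `62 * j`, and the last `1000 - 16 * 62 = 8` tokens are trimmed off. The same arithmetic in code (the variable names here are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus_len, bs, bptt = 1000, 16, 10\n",
    "\n",
    "n_steps = corpus_len // bs              # tokens per column (segment)\n",
    "col_starts = [j * n_steps for j in range(bs)]\n",
    "trimmed = corpus_len - n_steps * bs     # tokens that don't fit cleanly\n",
    "n_batches = -(-(n_steps - 1) // bptt)   # ceil((n_steps - 1) / bptt) with integer math\n",
    "\n",
    "print(n_steps, col_starts[:4], trimmed, n_batches)"
   ]
  },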
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "b1, b2, sl = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the actual data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be using the Penn Treebank dataset to benchmark our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "DATASET = \"penn\"\n",
    "DATA_DIR = Path(\"../data\") / DATASET"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be using a utility vocabulary class borrowed directly from the Transformer XL repo to numericalize our inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append(\"../utils\")\n",
    "from vocabulary import Vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "vocab = Vocab(special=[\"<eos>\"], lower_case=True)\n",
    "\n",
    "vocab.count_file(DATA_DIR / \"train.txt\")\n",
    "vocab.count_file(DATA_DIR / \"valid.txt\")\n",
    "vocab.count_file(DATA_DIR / \"test.txt\")\n",
    "None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "building vocab with min_freq=0, max_size=None\n",
      "final vocab size 10000 from 9999 unique tokens\n"
     ]
    }
   ],
   "source": [
    "vocab.build_vocab()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = vocab.encode_file(DATA_DIR / \"train.txt\", ordered=True, add_eos=True)\n",
    "valid_dataset = vocab.encode_file(DATA_DIR / \"valid.txt\", ordered=True, add_eos=True)\n",
    "test_dataset = vocab.encode_file(DATA_DIR / \"test.txt\", ordered=True, add_eos=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6503, 6151, 7924, 8539, 2353, 8540, 6918, 8541, 8542, 7394, 7925, 7926,\n",
       "        6152, 8543, 6504, 6919, 6920, 8544, 5560, 6153, 8545, 8546, 8547, 7927,\n",
       "           0, 9231,    2,    3,   73,  399,   34, 2136,    1,  146,   19,    6,\n",
       "        9232,  282,  450,    3,    0,   23,    2,   13,  142,    4,    2, 5090,\n",
       "           1, 2952])"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[:50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can prepare the data loaders:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_iter = LMDataLoader(train_dataset, config.bs, config.train_bptt, device=device)\n",
    "valid_iter = LMDataLoader(valid_dataset, config.bs, config.eval_bptt, device=device)\n",
    "test_iter = LMDataLoader(test_dataset, config.bs, config.eval_bptt, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[6503,    6,    1,    0,    5, 5485,  542,    2],\n",
       "         [6151,    2,  327,  909,    6,    7,   16,    9],\n",
       "         [7924,  225, 6798,    2, 2069,  410,  880,  157],\n",
       "         [8539,   72,  329,  380,    1,   15,    3,  299],\n",
       "         [2353,    0, 1538, 1635,  104,  495,    3,  232],\n",
       "         [8540,   29, 9744,   32,   90, 5086,    0,    6],\n",
       "         [6918,   84,   60,    2, 2392,    0,   14,    2],\n",
       "         [8541,   27, 1437,  506,    1,  112,   24,    2],\n",
       "         [8542, 2298,  105,  557,  320,  553,  195, 3524],\n",
       "         [7394,    0, 2262,   16, 1715,  497,  150,    0],\n",
       "         [7925,  495,    2,    1,    4,   68, 2440,   64],\n",
       "         [7926,    2,  885, 2978, 1046,  112,   11,   47],\n",
       "         [6152,   45,    2,    0,    8,  918,    1, 3464],\n",
       "         [8543, 1973, 5616,   39,  723,   25,  490,   14],\n",
       "         [6504,    2,    4,   13,    1, 3553,   75,  568],\n",
       "         [6919,    6,  630,  784,  284,    5, 2369,    1],\n",
       "         [6920, 2664,    0, 6171,  819,   25,    1,  569],\n",
       "         [8544,   10,   29,   20,    0,    1,    2,    4],\n",
       "         [5560,    2,   39, 2841, 8138,    2,    8,    1],\n",
       "         [6153,   21,    2,  310, 2066,   36,    2,  372],\n",
       "         [8545, 4648,    1,    3,  573,    1,    4,  160],\n",
       "         [8546,    2,   53,   18,   33, 5177, 2947,   15],\n",
       "         [8547,    0,  779,  351, 2753,   54, 1118,    2],\n",
       "         [7927,  215, 1099,  682,   18,    0,    9,    2],\n",
       "         [   0, 1193, 7123,    4,   12,   64,    2,    2],\n",
       "         [9231, 1135,    1,  156,    3,   47,    8, 1683],\n",
       "         [   2,    7, 5232,  996,   49, 5177,    4,    4],\n",
       "         [   3, 1573,    4, 3098,  151,    0,   51,  157],\n",
       "         [  73,    2,    2,    2,    3,  787,    2,  299],\n",
       "         [ 399, 4448,    0,    5,  122,  674, 4057,  232],\n",
       "         [  34,    0,   23,   32, 1226,    2,   21,    0],\n",
       "         [2136,    2, 1575,    2,  361,   72,    6,   14],\n",
       "         [   1,  586,  329, 4467,  573,    4, 2561,    9]]),\n",
       " tensor([[6151,    2,  327,  909,    6,    7,   16,    9],\n",
       "         [7924,  225, 6798,    2, 2069,  410,  880,  157],\n",
       "         [8539,   72,  329,  380,    1,   15,    3,  299],\n",
       "         [2353,    0, 1538, 1635,  104,  495,    3,  232],\n",
       "         [8540,   29, 9744,   32,   90, 5086,    0,    6],\n",
       "         [6918,   84,   60,    2, 2392,    0,   14,    2],\n",
       "         [8541,   27, 1437,  506,    1,  112,   24,    2],\n",
       "         [8542, 2298,  105,  557,  320,  553,  195, 3524],\n",
       "         [7394,    0, 2262,   16, 1715,  497,  150,    0],\n",
       "         [7925,  495,    2,    1,    4,   68, 2440,   64],\n",
       "         [7926,    2,  885, 2978, 1046,  112,   11,   47],\n",
       "         [6152,   45,    2,    0,    8,  918,    1, 3464],\n",
       "         [8543, 1973, 5616,   39,  723,   25,  490,   14],\n",
       "         [6504,    2,    4,   13,    1, 3553,   75,  568],\n",
       "         [6919,    6,  630,  784,  284,    5, 2369,    1],\n",
       "         [6920, 2664,    0, 6171,  819,   25,    1,  569],\n",
       "         [8544,   10,   29,   20,    0,    1,    2,    4],\n",
       "         [5560,    2,   39, 2841, 8138,    2,    8,    1],\n",
       "         [6153,   21,    2,  310, 2066,   36,    2,  372],\n",
       "         [8545, 4648,    1,    3,  573,    1,    4,  160],\n",
       "         [8546,    2,   53,   18,   33, 5177, 2947,   15],\n",
       "         [8547,    0,  779,  351, 2753,   54, 1118,    2],\n",
       "         [7927,  215, 1099,  682,   18,    0,    9,    2],\n",
       "         [   0, 1193, 7123,    4,   12,   64,    2,    2],\n",
       "         [9231, 1135,    1,  156,    3,   47,    8, 1683],\n",
       "         [   2,    7, 5232,  996,   49, 5177,    4,    4],\n",
       "         [   3, 1573,    4, 3098,  151,    0,   51,  157],\n",
       "         [  73,    2,    2,    2,    3,  787,    2,  299],\n",
       "         [ 399, 4448,    0,    5,  122,  674, 4057,  232],\n",
       "         [  34,    0,   23,   32, 1226,    2,   21,    0],\n",
       "         [2136,    2, 1575,    2,  361,   72,    6,   14],\n",
       "         [   1,  586,  329, 4467,  573,    4, 2561,    9],\n",
       "         [ 146,  158,   63,   17,   33,    1,    4,    1]]),\n",
       " 33)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(train_iter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We borrow the following initialization from the Transformer XL repo:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_weight(weight):\n",
    "    nn.init.normal_(weight, 0.0, 0.02)\n",
    "\n",
    "def init_bias(bias):\n",
    "    nn.init.constant_(bias, 0.0)\n",
    "    \n",
    "# Borrowed from the transformer XL repo\n",
    "def weights_init(m):\n",
    "    classname = m.__class__.__name__\n",
    "    if classname.find('Linear') != -1:\n",
    "        if hasattr(m, 'weight') and m.weight is not None:\n",
    "            init_weight(m.weight)\n",
    "        if hasattr(m, 'bias') and m.bias is not None:\n",
    "            init_bias(m.bias)\n",
    "    elif classname.find('Embedding') != -1:\n",
    "        if hasattr(m, 'weight'):\n",
    "            init_weight(m.weight)\n",
    "    elif classname.find('LayerNorm') != -1:\n",
    "        if hasattr(m, 'weight'):\n",
    "            nn.init.normal_(m.weight, 1.0, 0.02)\n",
    "        if hasattr(m, 'bias') and m.bias is not None:\n",
    "            init_bias(m.bias)\n",
    "    else:\n",
    "        if hasattr(m, 'u'):\n",
    "            init_weight(m.u)\n",
    "        if hasattr(m, 'v'):\n",
    "            init_weight(m.v)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
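    "No fancy initialization here: weights are drawn from a normal distribution with standard deviation 0.02, biases are set to zero, and Layer Normalization gains are drawn around 1. Since the network contains many Layer Normalization layers, which rescale the activations, this simple scheme is usually sufficient."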
   ]
  },
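  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A function like `weights_init` is meant to be passed to `nn.Module.apply`, which calls it on every submodule recursively. The sketch below uses a tiny stand-in network (not the Transformer XL) and a simplified hook `toy_init` (hypothetical, mirroring only the `Linear` and `LayerNorm` branches above) to show the effect: Linear weights end up with a standard deviation close to 0.02, biases at zero, and LayerNorm gains close to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "def toy_init(m):\n",
    "    # Simplified version of weights_init: only the Linear and LayerNorm branches\n",
    "    if isinstance(m, nn.Linear):\n",
    "        nn.init.normal_(m.weight, 0.0, 0.02)\n",
    "        if m.bias is not None:\n",
    "            nn.init.constant_(m.bias, 0.0)\n",
    "    elif isinstance(m, nn.LayerNorm):\n",
    "        nn.init.normal_(m.weight, 1.0, 0.02)\n",
    "        nn.init.constant_(m.bias, 0.0)\n",
    "\n",
    "toy_net = nn.Sequential(nn.Linear(256, 512), nn.LayerNorm(512), nn.Linear(512, 256))\n",
    "toy_net.apply(toy_init)  # visits every submodule, including nested ones\n",
    "\n",
    "w_std = toy_net[0].weight.std().item()\n",
    "print(f'Linear weight std: {w_std:.4f}')\n",
    "print('biases all zero:', bool((toy_net[2].bias == 0).all()))"
   ]
  },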
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training Loop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop is fairly standard. You can use any framework you like here, including [ignite](https://github.com/pytorch/ignite), [allennlp](https://github.com/allenai/allennlp), and [fastai](https://github.com/fastai/fastai). We'll write our own loop to keep things simple."
   ]
  },
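  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop reports perplexity as `math.exp(cur_loss)`, the exponential of the average per-token cross-entropy. A useful reference point: a model that predicts uniformly over the 10,000-word vocabulary built above has a loss of ln(10000), about 9.21, and hence a perplexity of 10,000, so a trained model should land well below that. A quick check with random targets (the variable names are illustrative):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "vocab_size = 10000                              # PTB vocabulary size reported above\n",
    "uniform_logits = torch.zeros(5, vocab_size)     # equal score for every word\n",
    "random_targets = torch.randint(vocab_size, (5,))\n",
    "# cross_entropy averages the negative log-likelihood over the 5 tokens\n",
    "uniform_loss = F.cross_entropy(uniform_logits, random_targets).item()\n",
    "print(uniform_loss, math.exp(uniform_loss))"
   ]
  },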
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim"
   ]
  },
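  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop below ramps the learning rate linearly for `config.warmup_step` steps and then hands control to `scheduler`, passing the current step. The scheduler is created outside this section; assuming it is a cosine annealing schedule like `torch.optim.lr_scheduler.CosineAnnealingLR`, decaying from `config.lr` to `config.min_lr` over `config.max_step` steps, the learning rate at a given step follows the closed form in the hypothetical helper `lr_at`. The 100-step warmup in the example is illustrative only; `config.warmup_step` is 0 above, so no warmup is applied in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def lr_at(step, base_lr, warmup_steps, max_steps, min_lr=0.0):\n",
    "    # Linear warmup: scale the base rate by step / warmup_steps\n",
    "    if step < warmup_steps:\n",
    "        return base_lr * step / warmup_steps\n",
    "    # Cosine annealing from base_lr down to min_lr over max_steps\n",
    "    # (assumed scheduler; this is the closed form of CosineAnnealingLR)\n",
    "    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * step / max_steps)) / 2\n",
    "\n",
    "for step in [0, 50, 100, 5000, 10000]:\n",
    "    print(step, lr_at(step, base_lr=2.5e-4, warmup_steps=100, max_steps=10000))"
   ]
  },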
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "\n",
    "loss_change = []      # training loss at every step\n",
    "val_loss_change = []  # validation loss at every evaluation\n",
    "\n",
    "def train_epoch(\n",
    "    epoch: int,\n",
    "    model: nn.Module,\n",
    "    train_loader: data.DataLoader,\n",
    "    val_loader: data.DataLoader,\n",
    "    optimizer: optim.Optimizer,\n",
    "    scheduler,\n",
    "    train_step_start: int = 0,\n",
    "):\n",
    "    # Turn on training mode which enables dropout.\n",
    "    model.train()\n",
    "    mems = None\n",
    "    train_step = train_step_start\n",
    "    train_loss = 0\n",
    "    log_start_time = time.time()\n",
    "    \n",
    "    pbar = tqdm(train_loader, total=min(config.max_step - train_step_start, len(train_loader)))\n",
    "    for batch_idx, (data, target, seq_len) in enumerate(pbar):\n",
    "        model.zero_grad()\n",
    "        # the memory returned for this segment is fed back in for the next one\n",
    "        out_dict = model(data, target, memory=mems)\n",
    "        loss, mems = out_dict[\"loss\"], out_dict[\"memory\"]\n",
    "\n",
    "        loss.backward()\n",
    "        train_loss += loss.item()\n",
    "        loss_change.append(loss.item())\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip)\n",
    "        optimizer.step()\n",
    "        \n",
    "        # step-wise learning rate annealing\n",
    "        train_step += 1\n",
    "        # linear warmup stage\n",
    "        if train_step < config.warmup_step:\n",
    "            curr_lr = config.lr * train_step / config.warmup_step\n",
    "            optimizer.param_groups[0]['lr'] = curr_lr\n",
    "        else:\n",
    "            # cosine annealing; passing the step explicitly emits a\n",
    "            # deprecation warning in recent PyTorch releases\n",
    "            scheduler.step(train_step)\n",
    "            \n",
    "        if train_step % config.log_interval == 0:\n",
    "            cur_loss = train_loss / config.log_interval\n",
    "            elapsed = time.time() - log_start_time\n",
    "            log_str = '| epoch {:3d} step {:>8d} | lr {:.3g} ' \\\n",
    "                      '| loss {:5.2f}'.format(\n",
    "                epoch, train_step, optimizer.param_groups[0]['lr'], cur_loss)\n",
    "            log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))\n",
    "            pbar.set_description(log_str)\n",
    "            train_loss = 0\n",
    "            log_start_time = time.time()\n",
    "\n",
    "        if train_step % config.eval_interval == 0:\n",
    "            val_loss = evaluate(model, val_loader)\n",
    "            val_loss_change.append(val_loss)\n",
    "\n",
    "        if train_step == config.max_step:\n",
    "            return train_step\n",
    "    return train_step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_loader, valid_loader):\n",
    "    optimizer = optim.Adam(model.parameters(), lr=config.lr)\n",
    "    # total number of optimizer steps the cosine schedule anneals over\n",
    "    total_steps = min(config.max_step, len(train_loader) * config.epochs)\n",
    "    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,\n",
    "                    total_steps, eta_min=config.min_lr)\n",
    "    train_step_start = 0\n",
    "    for epoch in range(config.epochs):\n",
    "        if train_step_start >= config.max_step:\n",
    "            break\n",
    "        train_step_start = train_epoch(\n",
    "            epoch,\n",
    "            model,\n",
    "            train_loader,\n",
    "            valid_loader,\n",
    "            optimizer,\n",
    "            scheduler,\n",
    "            train_step_start,\n",
    "        )"
   ]
  },
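  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the learning-rate curve that `train` and `train_epoch` produce together, here is a minimal standard-library sketch. It assumes the closed-form cosine annealing formula documented for PyTorch's `CosineAnnealingLR`, which is what `scheduler.step(train_step)` evaluates once the warmup stage is over. The helper `lr_at_step` and the numbers in the example are hypothetical stand-ins for the `config` values, not part of the notebook.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def lr_at_step(step: int, base_lr: float, warmup_steps: int,\n",
    "               total_steps: int, min_lr: float = 0.0) -> float:\n",
    "    # linear warmup stage, mirroring the branch in train_epoch\n",
    "    if step < warmup_steps:\n",
    "        return base_lr * step / warmup_steps\n",
    "    # closed-form cosine annealing from base_lr down to min_lr at total_steps\n",
    "    cos_term = (1 + math.cos(math.pi * step / total_steps)) / 2\n",
    "    return min_lr + (base_lr - min_lr) * cos_term\n",
    "\n",
    "# hypothetical settings, not the notebook's config\n",
    "for step in (0, 100, 200, 1000, 2000):\n",
    "    print(step, lr_at_step(step, base_lr=2.5e-4, warmup_steps=200, total_steps=2000))\n",
    "```\n",
    "\n",
    "Because the cosine curve is measured from step 0, the learning rate right after warmup is already slightly below `config.lr`."
   ]
  },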
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Language models are usually evaluated by perplexity, which is the exponential of the average per-word cross entropy loss. Equivalently, it is the reciprocal of the geometric mean of the probabilities the model assigns to the correct words. If the language model assigns a probability of 0.1 to each word in the input on average (in the geometric-mean sense), it receives a perplexity of 1 / 0.1 = 10.\n",
    "\n",
    "Intuitively, perplexity measures how uncertain the model is about the next word. A perplexity of 100 means the model is, on average, as unsure as if it had to pick uniformly among 100 equally likely words.\n",
    "\n",
    "Keeping this in mind, we can write the evaluation code like this. Because batches can have different lengths, each batch's loss is weighted by its `seq_len` before averaging; the function returns the average loss, which can be exponentiated to obtain perplexity."
   ]
  },
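  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two definitions of perplexity given earlier can be checked numerically with a small standard-library sketch. The helpers `perplexity` and `inverse_geometric_mean` are hypothetical (they are not used elsewhere in the notebook): the first exponentiates the average negative log-probability of the correct words, the second takes the reciprocal of their geometric mean.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import operator\n",
    "from functools import reduce\n",
    "from typing import List\n",
    "\n",
    "def perplexity(word_probs: List[float]) -> float:\n",
    "    # average negative log-likelihood (cross entropy) of the correct words\n",
    "    nll = -sum(math.log(p) for p in word_probs) / len(word_probs)\n",
    "    # perplexity is the exponential of that average\n",
    "    return math.exp(nll)\n",
    "\n",
    "def inverse_geometric_mean(word_probs: List[float]) -> float:\n",
    "    # reciprocal of the geometric mean of the assigned probabilities\n",
    "    product = reduce(operator.mul, word_probs, 1.0)\n",
    "    return 1.0 / product ** (1.0 / len(word_probs))\n",
    "\n",
    "probs = [0.1] * 20  # the model gives every correct word probability 0.1\n",
    "print(round(perplexity(probs), 6))  # 10.0\n",
    "print(round(inverse_geometric_mean(probs), 6))  # 10.0\n",
    "```"
   ]
  },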
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model: nn.Module, val_loader: data.DataLoader):\n",
    "    # Turn on evaluation mode which disables dropout.\n",
    "    model.eval()\n",
    "    model.reset_length(config.eval_bptt,\n",
    "        0, config.eval_mem_len+config.train_bptt-config.eval_bptt)\n",
    "\n",
    "    # Evaluation\n",
    "    total_len, total_loss = 0, 0.\n",
    "    with torch.no_grad():\n",
    "        mems = None\n",
    "        for i, (data, target, seq_len) in enumerate(val_loader):\n",
    "            out_dict = model(data, target, memory=mems)\n",
    "            loss, mems = out_dict[\"loss\"], out_dict[\"memory\"]\n",
    "            total_loss += seq_len * loss.float().item()\n",
    "            total_len += seq_len\n",
    "\n",
    "    # Switch back to the training mode\n",
    "    model.reset_length(config.train_bptt, 0, config.mem_len)\n",
    "    model.train()\n",
    "    return total_loss / total_len"
   ]
  },
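  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `seq_len` weighting in `evaluate` matters when batches have different lengths, such as a shorter final batch. The sketch below uses made-up per-batch losses and a hypothetical helper `weighted_mean_loss` to compare a plain mean of batch losses with the length-weighted mean that `evaluate` computes; it assumes each model call returns the mean loss over its batch, as the weighting in `evaluate` implies.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from typing import List, Tuple\n",
    "\n",
    "def weighted_mean_loss(batches: List[Tuple[float, int]]) -> float:\n",
    "    # each entry is (mean loss over the batch, seq_len of the batch)\n",
    "    total_loss = sum(loss * seq_len for loss, seq_len in batches)\n",
    "    total_len = sum(seq_len for _, seq_len in batches)\n",
    "    return total_loss / total_len\n",
    "\n",
    "# two full-length batches and one short final batch (made-up numbers)\n",
    "batches = [(6.0, 70), (6.2, 70), (7.0, 10)]\n",
    "naive = sum(loss for loss, _ in batches) / len(batches)\n",
    "weighted = weighted_mean_loss(batches)\n",
    "print(f'naive mean loss {naive:.2f}, ppl {math.exp(naive):.1f}')\n",
    "print(f'weighted mean loss {weighted:.2f}, ppl {math.exp(weighted):.1f}')\n",
    "```\n",
    "\n",
    "The plain mean lets the short batch count as much as a full one, which inflates the loss here; weighting by length gives every token the same influence."
   ]
  },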
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_final(model, val_loader):\n",
    "    model.eval()\n",
    "    total_len, total_loss = 0, 0.\n",
    "    start_time = time.time()\n",
    "    \n",
    "    model.reset_length(config.eval_bptt, 0, config.eval_mem_len + config.train_bptt - config.eval_bptt)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        mems = None\n",
    "        for i, (data, target, seq_len) in enumerate(val_loader):\n",
    "            out_dict = model(data, target, memory=mems)\n",
    "            loss, mems = out_dict[\"loss\"], out_dict[\"memory\"]\n",
    "            total_loss += seq_len * loss.item()\n",
    "            total_len += seq_len\n",
    "        total_time = time.time() - start_time\n",
    "    \n",
    "    model.reset_length(config.train_bptt, 0, config.mem_len)\n",
    "    loss_val = total_loss / total_len\n",
    "    return {\"loss\": loss_val, \"ppl\": math.exp(loss_val)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now all we have to do is initialize the model and start training it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_xl = TransformerXL(\n",
    "    num_embeddings=len(vocab), n_layers=config.n_layers,\n",
    "    n_heads=config.n_heads, d_model=config.d_model,\n",
    "    d_head_inner=config.d_head_inner, \n",
    "    d_ff_inner=config.d_ff_inner,\n",
    "    dropout=config.dropout,\n",
    "    dropouta=config.dropouta,\n",
    "    seq_len=config.train_bptt,\n",
    "    mem_len=config.mem_len,\n",
    ")\n",
    "if torch.cuda.is_available(): transformer_xl.cuda()\n",
    "transformer_xl.apply(weights_init);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "| epoch   0 step     3400 | lr 0.000132 | loss  6.22 | ppl   502.650: 100%|██████████| 3522/3522 [14:41<00:00,  4.00it/s]  \n",
      "| epoch   1 step     7000 | lr 2.41e-08 | loss  6.14 | ppl   462.692: 100%|██████████| 3522/3522 [14:58<00:00,  3.92it/s]  \n"
     ]
    }
   ],
   "source": [
    "train(\n",
    "    transformer_xl,\n",
    "    train_iter,\n",
    "    valid_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 6.067088326684745, 'ppl': 431.42268910022347}"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate_final(transformer_xl, valid_iter)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot how the loss changed over the course of training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x12e7e42e8>]"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd4FOX2B/DvSSP0GqSTUAQBpYUqIoqggIrXiuWqiAI2LNefonJtl3vBa++IiF1EVJArvYiAQiBA6C2QQOihJZCQfn5/7MxmdnZ2d3YzW3M+z5MnuzOzsyfJ5sw7byVmhhBCiMgSFewAhBBCWE+SuxBCRCBJ7kIIEYEkuQshRASS5C6EEBFIkrsQQkQgSe5CCBGBJLkLIUQEkuQuhBARKCZYb9ygQQNOTEwM1tsLIURY2rBhw0lmTvB0XNCSe2JiIlJTU4P19kIIEZaI6ICZ46RaRgghIpAkdyGEiECS3IUQIgJJchdCiAgkyV0IISKQJHchhIhAktyFECIChV1yP3QmH61fmI/cguJghyKEECEr7JL79NWZKC1jXPbK4mCHIoQQISvskvtTg9raH2edzg9iJEIIEbrCLrnXjI9F83pVAQD/mLU5yNEIIURoCrvkDgA3dWkKAFiXcTrIkQghRGgKy+T+xMC2ng8SQohKLCyTe0x0edh7jp8LYiRCCBGawjK5a+0+JsldCCH0TCV3InqCiLYR0XYietJgPxHR+0SUTkRbiKib9aEa+zXtSKDeSgghwobH5E5EnQA8BKAngM4ArieiNrrDhgBoq3yNBvCJxXE6GdKpEQAgoWYVf7+VEEKEHTMl90sApDBzPjOXAPgDwM26Y4YD+Jpt1gKoQ0SNLY7VwTPXtgMA9Eyq68+3EUKIsGQmuW8DcAUR1SeiagCGAmiuO6YpgCzN80PKNr+pHmdbIfDU+SJ/vo0QQoQlj8mdmXcCeB3AYgALAaQBKPXlzYhoNBGlElFqdna2L6ewqxoXDQCYOG9nhc4jhBCRyFSDKjN/zszdmbk/gDMA9ugOOQzH0nwzZZv+PFOZOZmZkxMSPC7e7VbV2OgKvV4IISKZ2d4yDZXvLWCrb/9ed8hcAPcqvWZ6A8hh5qOWRqoTG03+PL0QQoS1GJPH/UxE9QEUA3iUmc8S0VgAYOYpAObDVhefDiAfwEh/BKtFJMldCCFcMZXcmfkKg21TNI8ZwKMWxmXKFW0b4HxhSaDfVgghQl5Yj1CNi45CcWlZsMMQQoiQE9bJPSaaUFzCwQ5DCCFCjtk695C0aPvxYIcghBAhKaxL7ipblb8QQghVRCT34lJJ7kIIoRXWyf3vvVsCADYdPBPkSIQQIrSEdXL/Zu0BAMAdU9cGORIhhAgtYZ3chRBCGJPkLoQQESisk/sNnZsEOwQhhAhJYZ3cuzSvE+wQhBAiJIV1cpf+7UIIYSysk3uvpPrBDkEIIUJSWCf3S5vVDnYIQggRksI6uQshhDBmdiWmp4hoOxFtI6IZRBSv238/EWUTUZry9aB/whVCCGGGx+RORE0BjAOQzMydAEQDGGFw6Exm7qJ8TbM4To+kcVUIIcqZrZaJAVCViGIAVANwxH8h+WZdxulghyCEECHDY3Jn5sMA3gRwEMBR2Ba/Xmxw6C1EtIWIfiKi5hbH6dGF4tJAv6UQQoQsM9UydQEMB5AEoAmA6kR0j+6w/wFIZObLACwB8JWLc40molQiSs3Ozq5Y5IqXb+gAAKgZH9brjgghhKXMVMtcAyCDmbOZuRjALwD6ag9g5lPMXKg8nQagu9GJmHkqMyczc3JCQkJF4ra7pHEtAEBBsaylKoQQKjPJ/SCA3kRUjYgIwEAAO7UHEFFjzdMb9fv9KT42GgBQINUyQghh57Eug5lTiOgnABsBlADYBGAqEb0GIJWZ5wIYR0Q3KvtPA7jffyE7io+1XZ+k5C6EEOVMVVQz88sAXtZtfkmz
/3kAz1sYl2lVlZK7NKgKIUS5sB+hWjXOltzTsmSpPSGEUIV9cq8SbUvu3649GORIhBAidIR9co+LCfsfQQghLBf2mVGb3EtKpVFVCCGACEju0VFkf5x15kIQIxFCiNAR9sld641Fu4IdghBChISISu7ztx4LdghCCBESIiq5CyGEsImI5K7OLyOEEMImIpL733u3DHYIQggRUiIiuf+ta9NghyCEECElIpK7DGQSQghHEZEVtX3dtx3OCWIkQggRGiIiuWtd/8HqYIcghBBBF3HJXQghhMnkTkRPEdF2ItpGRDOIKF63vwoRzSSidCJKIaJEfwQrhBDCHDMLZDcFMA5AMjN3AhANYITusFEAzjBzGwDvAHjd6kCFEEKYZ7ZaJgZAVSKKAVANwBHd/uEAvlIe/wRgoLLeqhBCiCDwmNyZ+TCAN2FbKPsogBxmXqw7rCmALOX4EgA5AOpbG6oQQgizzFTL1IWtZJ4EoAmA6kR0jy9vRkSjiSiViFKzs7N9OYVLk2++1NLzCSFEODNTLXMNgAxmzmbmYgC/AOirO+YwgOYAoFTd1AZwSn8iZp7KzMnMnJyQkFCxyHVqV421Pz5XUGzpuYUQItyYSe4HAfQmompKPfpAADt1x8wFcJ/y+FYAy5mZrQvTM20V//HcAhQUlwby7YUQIqSYqXNPga2RdCOArcprphLRa0R0o3LY5wDqE1E6gKcBjPdTvC71blXP/viat1di1FfrAx2CEEKEjBgzBzHzywBe1m1+SbO/AMBtFsbltTrV4hye/5nuVCskhBCVhoxQFUKICCTJXQghIlBEJffGteM9HySEEJVARCX3Gzs3CXYIQggREiIquf9jcLtghyCEECEhopK7rMgkhBA2kg2FECICSXIXQogIFNHJvbQsoDMgCCFEyIi45P7QFUn2x5MX6KfAEUKIyiHikvt9fRPtjz9blRG8QIQQIogiLrk3q1st2CEIIUTQRVxyF0IIIcldCCEikiR3IYSIQGbWUG1HRGmar1wielJ3zAAiytEc85Kr8wkhhPA/j4t1MPNuAF0AgIiiYVsvdbbBoauY+Xprw6u44tIyxEbLDYoQonLxNusNBLCPmQ/4IxirfDmyh/1x2xcXBDESIYQIDm+T+wgAM1zs60NEm4loARF1rGBcFTKgXUOH5wFeq1sIIYLOdHInojgANwKYZbB7I4CWzNwZwAcA5rg4x2giSiWi1OzsbF/i9cnXa0L6RkMIISznTcl9CICNzHxcv4OZc5n5vPJ4PoBYImpgcNxUZk5m5uSEhASfg/ZWWtbZgL2XEEKEAm+S+51wUSVDRI2IiJTHPZXznqp4eNZYsftEsEMQQoiA8thbBgCIqDqAQQDGaLaNBQBmngLgVgAPE1EJgAsARnAIVXSfyS8OdghCCBFQppI7M+cBqK/bNkXz+EMAH1obmhBCCF9FbAfwBjWqODyft+VokCIRQojAi9jk3rK+4+yQj36/MUiRCCFE4EVscq9TNTbYIQghRNBEbHIXQojKLGKTe0w0BTsEIYQImohN7v+6qZPTtu1HcoIQiRBCBF7EJveGNeOdtu06ei4IkQghROBFbHI38tzPW4IdghBCBESlSu4lZSEzaFYIIfwqopP7mP6tnLZ98WdGECIRQojAiujkrsxl5uDV/+0IQiRCCBFYEZ3cGcbVMLuPScOqECKyRXRyd5HbkVsgs0QKISJbRCd3V82nt01ZgxPnCgIaixBCBFJEJ/ebujR1uS/r9IUARiKEEIEV0cm9Q5NauKFzE7+cO9BrkTAz5m05iuLSsoC+rxAiPHlM7kTUjojSNF+5RPSk7hgioveJKJ2IthBRN/+F7J0nr2lr+Tl/XJ+FpOfnI/tcoeXndmXZzhN49PuNeH/Z3oC9pxAifHlM7sy8m5m7MHMXAN0B5AOYrTtsCIC2ytdoAJ9YHaivWifUsPycc9IOAwB2Hcs1/ZqyMsaBU3k+v+fp/CIAwNEcaSsQQnjmbbXMQAD7
mPmAbvtwAF+zzVoAdYiosSUR+smFolKfXxsdZes/P21VBkpNjnqdsnIfrnxjhXTDFEIEhLfJfQSAGQbbmwLI0jw/pGxzQESjiSiViFKzs7O9fGvfdW5W22nbPZ+nVPi8f+zJxqzULM8HAlifcRoAcOhMPibN34lth2WGSiGE/5hO7kQUB+BGALN8fTNmnsrMycycnJCQ4OtpvGY0UhUAZq4/CAA4cCoP7SYswP7s8w77F2w9ioLiUmSezMNTM9MMGzPzvLwDKCopw6cr9+Pmj/8CABSXluGv9JNencMfjpy9gKzT+cEOQwhhEW9K7kMAbGTm4wb7DgNornneTNkWEqJcrNvx3M9bcSK3AC/M3orCkjLM3lQecmrmaTz83UZMnLcD//fTZszedBibDp71OQb9Baa4zHaheGvxHtw1LQUbDpzx+dyeTFu1H28v2eP2mL6Tl+OK//7utxiEEIHlTXK/E8ZVMgAwF8C9Sq+Z3gBymPlohaOziKuSOwDMWJeFP9NPAQA+WJ5uL83nXLCNYj1y1n0D5t7j50yVePVdJ9Wn6Sdsdwunzvuv583EeTtDvpdNaRnjo9/Tcb6wJNihCBERTCV3IqoOYBCAXzTbxhLRWOXpfAD7AaQD+AzAIxbHWSGuSu4A8M5SxxLtcz9vBVCefLUvNerb/sP6LK9KvN7OOrzx4Bkkjp+Hw2cie9DV/K1H8cai3Xh9wS5Tx3+8Ih2J4+dZ3u9/59FcvL5wV8DHMQhhNVPJnZnzmLk+M+dotk1h5inKY2bmR5m5NTNfysyp/grYF+5K7q6o/9rp2eftiT63wFaqNPN/n3U6H51fXYyMk3kOMZRpXlxQXIqlO41qucr9sM52J7FaqZf3R84x2+PHnwpLbEk6r8hcyf2T3/cBAC4Um2vz2HY4B28v3u3xuNs/XYNPVuyTOwgR9iJ6hKrKXcndyJp9p+yPD5zKR6pSH/7Q166vWcdyCrD1UHkPmPlbjyLnQjHGzdiEnPzyicrUJAYAszYcsj82Sq8p+0/hx1TbMd5enrYfyTE9f84dn67x8uzWq2hJ+URuAX7S/D71bvhwNd5fnu7xPGXKhc6XAoEQoaRSJPebuzbz6vgHvlyPZR5K1Hq9Jy3DDR+utj/foiT6rYdz0Pm1xVi+6wQA4JlZm+3HaBPamG82AAA+XL4XiePngZlxx9S19v2n84q8imfY+6vR89/LTB2bqmvMXbkn29724A+r957EPl3PJBWZvIzpLwX3f7Eez8za7PL3pP6qzV5EpFpGhLuYYAcQCLf3aI7s84V4Y5Hn23LAdqv/w3pz/dddmbfVc3vyiVznRtQ3Fxv3atmvVO+4mqPeSvdOXwcAaFgrHle1a2jZeZkZ36w9gJd+3Q4AyJw8rHxfBc+drTRIl3hRB79w2zE0q1sVnZo6j4MQwh+YGTuO5qJjE/9/5ipFyR0AHr2qjSXnmTR/p9elaFc+/N11NYHLgqNu+4WiUtw9bS1un7IGN3/8J/Zln0fmSffTHIz/eQsufXmRx/hGfrHe4zHeSMk4bU/srpitDVHrxAuUcQbqyzw1H2h/r2O/3YDrP1jtuF/5fiZP5vwX1vtl42EMe381luzwrmbAF5UmuVvl05X7seOo+TllvDFnk/dDA1IyTuHP9FNYl3kaGw+excC3/sCAN1cYHqvWJ/+wPgvnAtBgWFrGKCwpb/AsKrF+Rss/9thGOkcZNFgbMdp7Oq/I3vU1X7lYPPvzZoMjI9u3aw/Yu+YK75zILcDor1M9NsTvPm6bfkQ/YNIfJLlb7LctR/DsT74lhidnptkfpx0yHjClT06eqjPKNEVZRmB7xtz/xTq0m7DQ/jxKVyyfvekQjqkToSlhHc25gMTx85CWZW7AmNrwqc73oyb3TQfPoMBkT5pu/1qCrq8tdtiWX4G5h8LVhDnbMOz9VcEOIyTlXCjG4zM22QsBeu8s3YvFO47j1zRzBbRA/BdKcrfYY99vsvdwqYjUzNOG2w+d
sQ2YKiktw4WiUuw44v4uYtaG8raDMmb8d5FjP3Jv6qhVx3MLHErkrqza6zitgr7X0lMzN+OuabZG4wXbbG0U6oCyH9YdxOMzNmH4h47VJp4w2y4Qf/v4Lzz/y1aD/cb/Vt5c8xZvP4ZF2495FVewMDOmr85wmZT0Cn24u1q+6zhGu+lJFkiJ4+dh4m87KnyerNP5uPKN3+2Fj+mrM/C/zUcwfXWG29eFUju8JPcQ9Z/5xoN51meewV/pJzHmmw245KWFHhuJ1UFZAPDZqv1YtK08KeUVlmCahw+r6t/zduDH9VkoK2P0+s8yPKW5y9DbdPCMQzJJHD8P36UcQJRBn9T92XlIHD8Pv+92nkjuf5uPYPMhcxOsRSmf5KLSMuQpt8ZbXNz9uGJ2lPDobzbYezf5AzNj0fZjPl14Adt8RergrjX7T+G133ZgwpxtVobo4IEvU7E4AHXIZk1bnYE9xys2++q3KQdw4FS+fUoS9a7TVeEgFHvOSnIPQ3dNS8EypWulN/67cLdD/+1pqzKceuwYVYdsP5KDz1Zl4Nmft9irPRZtN/5nLi4tw98+/gujvnRsjJ2+OsOpWsZX2qSnnlE998C3/sC2w7a7GaP/Q3cFq4e/22h/rHZlPXGuwO2t9t7j57B813FL66qX7TyBMd9swEfKQC0z8otKcPuUNdh7/Bw6vrwIfSbZusGqJfFckyV3f1m7/xRemeu+Md1Kg99Zaen51I9uRWs1A3kNqFTJ/b0RXYIdQtBpSx4lZWVO1QtLdjg+n7/1KIa9X141on64S8vY/tqcC8V4Ze52FJaU2uv0t+imNN6XnYeMkxVPgFP+2Ic7Pyvv/79fOeeBU+Xz+9hH82pep/5zfr3mAApLSnE237nH04lc50Ff901fjyd+SDM8HgAGvbMSD3yZimve/sPbH8Ul9a4n04vFXf5SGtUnLdiFopIynDzvXY8uK/r1uzvHiKlr8eVfmRV+j4p4b+lebDbZlqN+eNQR0FH25G5NvUsgqm8qVXIf7mbB7MoiU5MEo4hw+KzjnDX6Erm+RKrtZ69WTby1eDe+/CsT7SYsxGQ3c8Noq4i8lV9UgvQT5zB5wS6szywfdGVUulVHqmqTjVpi+tdvO3D/9PXo8toSj++55/g57FR6RgWyITo+NhoATDcIAxWvFrAi2XhzjtIyRm6Bf+8mMk/m4bjmgv3O0j0Y/tGfpl6rzgCrTrgXZW+wd/86T7+CQwGcI6pSJXfhyKikpU/m+pKK0T9wcWn5RrV0VtFuj9pBZO0mLMDIL9bjmre9u9V29Y+2Zv8pF3scaW/tzU5HUFJahjcW7cKp84V4b+lerxK0Klr5r/x99wmc83MCVFlRIlXPUFxaVt4LCsaN9hPn7cBlryyu0Ipongx4cwV6/cfcKG09/YVHHTntss7dxDlX7c02NbjRKpUuuf/z+g7BDiFkmJlrRb9i1Du6eeFX7z2JlAxzydJXhSVlSMkw7j2katPQea1cbxOW9q5Gz2zBeN7Wo/jo933oPnEp3lm6B9NW7bfvO3W+ECe9mNq5oLjMbcO1lay4L1ET3ytzt6P3pGX2BDlx3k6nY9W7q4oUAsrKGEdzvC8J788+j3yTE9Sp9NUyPf+9FM//ssXta+6bvg7JE5cgv6gEJ88XYtfRwC6xWemS+6h+SYjxdiaxSmzpTseG209X7nd4fs/nKdif7fvC31a5qFYVp21qbs86nV/hhjCVdopho3n8tXcxgGPXwu4TlyJ54lLD87qq0/f377awpBQT5mzFGS9HXY/5JhU3f2xcxaHOo3RemUV1xe7yzxAz40xeEc4p+yrSwvjusr3oM2m5vXuwK5kn85CiuVu7+q0/8OBXxl03s88VOkz0p4rWVcucOFeIGevcT1Hyx55snDxfhJs++tPl392fKl1yB4DZj1we7BCEhZ6emWbvH6+lJnej/u7eUmtlFmi6kpqZx99M7tpxJBddXltiej1erWM5BQ6JebmHXlTq+gAH
lMbaj5an49u1B9HTy+qLRduPY6NuZTL1slY+FYRti/aOqIyB0d+UJ1a1tP/i7K1eD6BaqYxOPm4wR5PWgDdXOEzCBwB/7TuFM3lFTktc9vj3UnSfuAS7XCxkn3Oh2GFgoN6OI85dd/cct1V1Brq7pNnFOuoQ0U9EtIuIdhJRH93+AUSUQ0RpytdL/gnXGg0NSnkifP3iYtqG0jIGM1vyT0UgPPb9RoybscnlMbkFxU63+0eUuuc8N8PS1T7Z6qAvh0FwbmLPKyxB70nLkPzvpaZ/RrU6ZKXyXkZVc/p65V3Hcl1WY2gbmktKHadLLjOocWFmh2Svvv67lIPY7mFAnmru5iP2aSds72fqZU7+Pj0Fd01LcVrwpUSXvHPyi+1rOfy04RAe0XSZ1ZuxLsupKtOImUGAFWW25P4egIXM3B5AZwDOlWjAKmbuony9ZlmEfpBQQ5J7ZXAstwBJz8+35Fzp2efx2xb3jWGXvbLYaWK0nzYcwrqM0+joYqK2EVPXOKwGdiynwKH07SpvFRSX2s9ppiePvv3BXT78NqV8uufCklJc9+4qPKpJaNqJ6bq8Wj5tw88bbReOKDdZRR+pL9Vl42Zswn3T19nP5cucTAAc6sD3ZZ932ZPm8NkLSKhZnjMWehidfOSs53aAd5f6f9lLj8mdiGoD6A/gcwBg5iJm9n2l6BBgNFJSRC4rFt7Q9v7w1u1uFkNZu/+0vY/+pqwzKDEq7hrQN2wbzYO/+9g5exXCqr0nMejtPxx6Ow1y0Tf/t81H7I/VC0dKxmks33UciePnOUxMp52AbsKcbTidV2R/j1N5ztUlZczIPlfo8Nwo7gNm+vgrr/16zQHc/8U6rxqrAaBU895vLd7tsg88EVBN6Z7q8lyaq9TRnAL8048jgs0yU3JPApAN4Asi2kRE05Q1VfX6ENFmIlpARB2NTkREo4kolYhSs7Odh5sL4Q9WXMrNJl1v6Ks/sk5fcLoQZZ2+gFs/+QvnC0scSoT6hm2juW7OF5ZglKbhcO+J85ihLNv40q/bsNfFqNrCkjLkF5Xg5w2H7Ik6v6gUD3zpef6Y6asz7H25jVYuO57jmICNkvu1767ElW+scNg2Z9Nhp4ZO7StX7M72utFSu0i9u45VROUNqkYyTuY5dN19ee52fLP2gFex+IOZ5B4DoBuAT5i5K4A8AON1x2wE0JKZOwP4AMAcoxMx81RmTmbm5ISEhAqEXXGdm8kCDZWFqVKgB/4YfPKmwZqu+hxSVFqG1ANn0OnlReg7ebnLOwhvF5dxVx2SlnUWE2Zvwz9mbfa6m6u2H7fRKNli3UXSTJXS/uzzeHJmGp6Y6djeofbGqagh761y6uWkxez+bn/3MXNtBZ+sMD+dhBXMJPdDAA4xc4ry/CfYkr0dM+cy83nl8XwAsUTUwNJILfbrY/2CHYIIEHf9180yu4qXGd8qpTqj0bWe5t85eb7Q9GhZd6VNT7KU7oWfm5xYTlWoG7Sl7+KpLyG7KzGrdekFxbYLword2UgcP8++/4gPfdxdcbfecBkzot3+Xcz9nk9ZtMiPWR6TOzMfA5BFRO2UTQMBOMypSUSNSLmfJKKeynn9O7JFiDA1Yc42l5ORuVvkW2W2AfHuz9Z6PsgFdQI5oy6m7hzR3Vn8tU//esdsXlrGmLTAqH+GbX0DZnY5GM2qieiA8onijDA731FpheKMkID53jKPA/iOiLYA6ALgP0Q0lojGKvtvBbCNiDYDeB/ACA6DFYZXPDPA4fmIHs2DE4iodJ74wXjkqZk7hAMGg6eM5FVgaL+7agpv6LsN/pp2xOH56fwifPpHefuBtmQOGC+FqArUgiq5BcVuV1gK0dxuboFsZk4DkKzbPEWz/0MAH1oYV0AkNnBsF558y2UVXhhbCH9au/+UfTKrcPSBrl+9pyKgq6mlA+muz1Jc7isqKbOkN5Y/VMoRqlpJSoKf8VBvw/2/PiqjWUXoMJqnRQTPxRMWBGRAki9Mldwj2ZxH
L8fZ/CK0rG/UuxPo2KQWGtWKxzGDub6FEBXz9I+BmRjNn0J1UfFKX3KvXTXWZWJX3dWrRYCiEaJyOWBBT6ZgC8RoU19U+uTuCRG5bSkXQohQJMndg+gowpBLGxvua1QrPsDRCCGEOZLcXfjn9R3QsUktAEDrhBr48K6uQY5ICCHMq/QNqq6M6peEUf2S7M+NBkwEqgdUgxpxXi94LISo3KTkbpLRUO537ugSkPd+YmDbgLyPECJySMld58Whl2D3cedVWGKjHZN7TBShd6v6AYmpodTtCyG8JCV3nYf6t8Kbt3V22h6tW4Fg2n22AbutEtx3o2xap6rLfWYmdqpfPc5e92/kfxZNgFanWqwl5xFChAZJ7ibV1SW/Ae0aAgB+HNMHk26+1PA1r9zQweH5W8pFo1GtePw5/mpse+VajLu6jeFrMyYNxRcje2D5MwNQo4rrG6xLLZq6uHqc3MQJEUkkuZt0WbM6+OTubk7bG9Sogms7NgLgeAHInDwM91+e5HBsfGw0vhzZA/PG9UPTOlVRNS4aTw9uByNEhKvaNUTtqrGoUy3ObWw3dWni7Y8jhAgifTWvP0hy94La371bizqmX6PtURMdZSvx16/AGq5jr2xdHk8n20XlwStaeX2eDo0dq3oa1bbdTQgh/K92Vf9Xg0py99LuidfhxzF9fHptxyYVr0IZP6S9/bF64WidUMPr8+gnSvtyZA+37QNCCOs8dpVxdayVJLl7qUpMNGKiXf/a+rVpgHc1XSS1Jffm9aoZvmbKPd19ikWdLrVqXDQyJw/DeyO6YGD7hvb9i57s7/K1+hXqa8ZLg6oQgeJu2T6rmGpFI6I6AKYB6ATbUioPMPMazX4C8B6AoQDyAdzPzBuNzhWJtOuSfPtgL4d9PRLrIeu0+5Vzqldxv7I6ADSrW9VhUBUAxOouMsO7NMXwLk3tCx7ExzpfhH4c0wdHzl6wdBUbIYR3AvHfZ7bk/h6AhczcHkBnAPpJpYcAaKt8jQbwiWURhjl9QvbV6ueuxkhdA62ni0ILgzuFnkn1cFPXpiGb3O/r09L0sZ2auu4iKkQoC8QCHx6TOxHVBtAfwOcAwMxFzHxWd9hwAF+zzVoAdYjIeLatCOTuD6Um0XYX1fTHOxtVKT+/AAAU1klEQVRuHd6lCRrVincbl6tdE4ZdYkVgLi37x5X4Xnd3o/Xq8E6mzzXokkYOz+tXd9+rCABWPXuV6fNb7f6+iUF7bxFaAlG2MlNyTwKQDeALItpERNOISD9ypykA7fp0h5RtlYK75WLVPyLD+iVlXX1A3hvRFWtfGOjTax+8ohU2vzzYb635rRNqoG+bBpac64F+iV6/xlW7RyCMvDwxaO8tQkuX5uZ73PnKTHKPAdANwCfM3BVAHoDxvrwZEY0molQiSs3OzvblFCHNqKRMfqxd8+bMrw3viB9Gl/eQcRdX7aqx+O1xa0a+ekPt+3t1+4aIi/H80YyPLa+Wctd47K2/dfVcLvGlFJ5Q0/cusCKytG3ojzt5R2aS+yEAh5hZXSX2J9iSvdZhAM01z5sp2xww81RmTmbm5ISEBF/iDWlGJXh7yd36grupW7u3b++M7x7shXv7JDrMhRPjobXeH/FqLXzyCjx7neMArisvtn0mpt/fA3+Z6HOv/QnaNarYP0sDzdgDMwNMXrmxo9fTQFeTUcBCERKDmJj5GIAsIlL/EwcC2KE7bC6Ae8mmN4AcZj5qbajhSf0TlvkhW7Zv5LlB8eZuzXC5QTWIp65Y2mqkYZf53nySOXkYMicPc9revlEtPDKgDUb3Lx+A9d4I75Kl/k7J7G/4uo6NnLa1qFfexz86KsppugmtzsqUDz2T6pl8RyEchUSDquJxAN8R0RYAXQD8h4jGEtFYZf98APsBpAP4DMAjlkcaBoz+YGr/8Uubeh7A1LROVYy9sjVmjTU3SOpuP67tWqbJlPf1STT9usT61Rz6+Xvy9KCL7Y+ra+bQcfXRH9AuweMxnnzg
ocSd3LKu00RxWr8qk7U1rBmPy9sEZmZQER66ejF63d9MJXdmTlOqUy5j5puY+QwzT2HmKcp+ZuZHmbk1M1/KzKn+DTt8NKodj9mP9MXkWy5zeYxa//b0oIsxfkh79Eg0VyK06upf02BiMl/vNGaN7YubTNRZq7T15mZc3rr8LkT98d3NmmlEPz5A75buzeDqEP1FundS+Cf3fhY1cAvgjVtd/58HmoxQtYBa4rw9ubnh/q4t6rpNYo1qxyNz8jDc0r2ZX+LzxKjxUm0/8HYq4AY1PHdHrIhBHS6yPyYizBrbB9+56VrpSbU4x79LrXjb3/KrB3o6HdsqobrT1BN+bpoIiIYBauhNrB+8nkqqWvExqGKisd5XLeq5nwI8kCS5WyA+Nhq7J16H564znuEx1BndAKjVMg1qVHGYd37zS4MNzxEXHYUVzwyw7G7C6DyZk4chsYHjP0+PxHoeZ810p5dSb14lxpbk+yp3Bu0b1bI38Krmj7sCVXUXAzNz8gPAIwNaez4oSFwtAK8yGgzni58f7mvJedy5+CL38yyteu5qn/uY39nTuPCmRQTc3C00eoFLcrdIlZjogDSSBIpaLRNNZGoWzG9G9XRKvP4yb1w/h26dFaH+zarGRWPZP67EuyPK2wu6t6yrO9b59aP6JaFxbc8rZWnbFswysxCLu4Rp9i5KezdkpFTTAPO5skiNL+rXqIL2FezV5El1N2sfAN7NxqjvDnyRiRXRCP7vaWaWJHcBo6bJsjJlD9kSYM149/80nj7PL13fAa/fYryoiRnafuwdm9Q2XOKwp4e2ivdGODf0qiXxutXi0DqhhkP1mb5UbjQ2ID42Gi+aGNXrbrI5VyqyEMu4gW2x0qLRuNr2l4GXlF8IfGk8TPJzAeCTu7vjhaHt3R5jduxJSx+rkZ68xv2ax5/+3beJAr0lyT0MrXr2KksH7RhXy7CyT7fTxf+Fp9LKA/2ScEcP87176lSNxYB2Cfj+wV7InDzMVD/2d+7ogoVPXmF//sXIHph2b3lJc3gX59vly1s3wL+Gd8Srwzuajk2vicFUyf93bTv88ojnagjtFM6+Mf7FN6gR57Zf/cu6VcJcIQJuc9GW1KS26ymiL2ls3Mj9xm2d8dm9vpf+PWlUOx6j+7fG7Ef64h+DLjbsrmr2Btupq62JEjkRoWV99xcwbzsA+EqSexhqXq9ahQftaBklJ7WhsZWbktbq567C0qf7486eLdAjsa7L43wRFUX4cmRPr6YqqBoX7dD3/6p2DXGNhyoHIuDvfRLdLmWoPdZItxZ1Mfexy+3Pq8VF4+5eLdCtheffSeuEGpg3rp/L7qPN67mfY99VwlGrEFxVzdzctZnhGIRbDRr1n9KVRNXf1cNu2hGGu1gdrEaVGI/VQFbo2qIuHh/YFv8c5nwRa9PQdb18Yv1q9tlUtX9uT4P+VGaOqhWAhToASe6VmlqC+KdBtUKrhBr4/L5k/NdN165mdauhTcOamHTzpT5VO4QCb6bVdnfoZc3Kqyh2vHady0ZefbKNIls105BLHQdWqaX+oZ3cN3a6KkwOVhLoH//nomrGxQ/zgG7mUYJzCbZJHfXCYb6XjaeLlK+WPGX+Dla9m/hyZE98ObKH4THa36f2x055YaCpnlHu7gr2/2cotrwyGLUCtHaCjIcWLrtpautXI6LPnwFvGsF9aTDfMOEah+mV1cftG9XErmPn7M/19cCNlJL3U4MuxrnCEnyfctDjeyU1qI6Mk3mIjiJ7rK4aGGu5aEPpoKsyMPqZp9/fAwu3HUMjTUPy80Pa41huAb74MxOA8x1FYzdVOBXhTSl4wRO2Krt61ePsC9ybUS0u2vTSmG5niI2igCV2QJK78EFcdBSKSsuCHYYl3PV51q8z60tfKH1S0M81pD53lRPiY6NxT6+WppJ7s7pVMWtsH8S6GV1bHoe5n0ZdjvHTv3e3X3Ca1a3mtG7vGGVtXzW5O/FT4cDTT3FRLe/68I+7ui0mzNlm
f77llcGItl+Aw0t43ksLS/jaZev3/xvgtAarXlxMFNq6qdsMNnV+GXejVa9q39Cr+d899dYBgBgl8aoN1mrJPVrfeOfw2PUfSv83bFCjCmp7OfDMHbVB8tqOjdDZxDS1614ciD6t6uOOHo6NsPf2Nb8IiyffjioftObpI9zQRPdF1XcP9sIt3Zs5dJesFR9rv/upyPUpGP8LUnIXXiGyzYHjaTHt3f+6zvQ5vx3Vy9L541+9sSOa1fXQEAk1ubo/l3b+d0+F3a9H9UR+UanHY37ecAhpWWex98R5e3LXT+SmnWHUU++LUNKwZjxmGIxBsHKxmn5t/TNdglrKnzmmN37fdcKyWTyXPn1lUKZ7lpK78IrZ0j4Rmb7179e2QYX6dOvd1zfRsb3AQJm9WsS6Ovf42GjU87AaVOuEGnj2uvb2cQNVDNa5BRx/z+568piZA6giA4+sov/VaXvp3OSiZ40RfTWaVYWCwR0uQhtljqeW9avjfl3DMgCnD/9tbqYL0c5B1KZhDb8tfuOOJHdRKZWVOVaLBNrkmy/Di0MvQXJL4+6SnqZkVmlHj7q6+Hi60AG2hdN/ftjcbKRaL9/QwdTiJu6868VUz88MdpziIz422nBKaTM2TLjG3p2ztQ/VJiN6Oo7b2DNxiE9x+Isk90pM/UCb6eOtipQZFtS+/d5OjGaVutXj8FD/VoYJuVdSPVxk8jbe7Nw2nvRMqofuLb2fn37k5Ul4x4spnn2hHdV6T2/r6u7r16hiv4MycwekP6J7y7pY83z5ojLaCfj8saymtyS5V2Kv33Ipvn7A3Jwwb9zWGW0b1kD1CFlN6MVhl+Dju7uZnl45kGaO6WN63ECvsFkwxP1FSN+/Xuv3ZwbYH+snbquoIco4AjN3H70Mpnd21cWzLAQ6k5n6BBFRJhFtJaI0InKaq52IBhBRjrI/jYhesj5UYbVqcTHof7G55Q6v69QIS56+0rKSYiDFx0bhwX5Jum3RGOphNsRwEEmT1VXE04Mudjua2pWkBtWROXmYqVXN+rVtgO2vXmvqvP5Yec1b3hTDrmLmk272r2Lm6ysakBBW2/Wv0KoL9Zdnrw3PKacBz9V9vZLqOayX8M/rO2Dx9mP25+MGtsW4ge4n7LKCp1knVSGQ26VaRohw1K1FHae+053cLOWYMWmov0Nyou3+Z/YGY/yQ9ujdyrmqaeaYPg6L2Yzql4SZY7xvAA6UAe1td8TB6CWjMpvcGcBiItpARKNdHNOHiDYT0QIi8n2KPSEqMTMDhQDgl0cux5KnrzR93mBU3zx1jfk57NXo9IO5QtFvj/fDF/cbz02j6qNMSW327+kPZqtl+jHzYSJqCGAJEe1i5pWa/RsBtGTm80Q0FMAcAE73SMqFYTQAtGjhv8WdhQhHGyZc4/K2/8Whl6BxnXg89v2mgMVjtPyir8zWhzMY7RvVwtr9py17b6u5u0NSlY9A9nc0rplK7sx8WPl+gohmA+gJYKVmf67m8Xwi+piIGujr6Jl5KoCpAJCcnBwCtVJChA53k1M91N82l0uPxHoVnuOkvoeBViqrEtOdPZt7defwwtBLMOyyxrhtyhpL3v+t2zpj6+EcS85l5McxfVCsm2upWV3byOY+BovKBIrH5E5E1QFEMfM55fFgAK/pjmkE4DgzMxH1hK2655Q/AhaiMjNa6u2rB3p6tfj02hcGmjouEAO8jOZ2j4uJsrSL6i3dm/l18XmjBUEuvqgmVj17lcdpMPzJTMn9IgCzlStvDIDvmXkhEY0FAGaeAuBWAA8TUQmACwBGMIdCe7EQkU+/kLcn7iZL0/J3ateOLDW6jky6+VLTsYai5hYtLO4rj8mdmfcD6GywfYrm8YcAPrQ2NCFEMAVjagZtkfDOntIuVxGRMdxQCD9687bOuFBUEuwwAiYuJgpFJWX2ev5AkMFY1pPkLoQHRuuKRrLfHu+HhduOVXhQUCjMr1KZSXIXohLRry5l5OKLauJiC+dfF8EhyV2I
SuK3x/uhed3ANfLp14U1Q8r61pHkLkQlYWbwjZW8qZaRGnfrhW8/IyFExJEO1NaR5C6E8AuvqmWk6G45qZYRQviFN9Uyo69ohf3ZebhL+rZbRpK7EMLPPBfL69eogs/uDf5C3pFEqmWEEH4mFenBIMldCCEikCR3IYSfSWtpMEhyF0L4RWyULb3ERUtyDwZpUBVC+MVNXZti38nzeOyqNsEOpVKS5C6E8Iu4mCg8P+SSYIdRaZmqliGiTCLaSkRpRJRqsJ+I6H0iSieiLUTUzfpQhRBCmOVNyf0q/ZqoGkNgWxC7LYBeAD5RvgshhAgCqxpUhwP4mm3WAqhDRI0tOrcQQggvmU3uDGAxEW0gotEG+5sCyNI8P6RsE0IIEQRmq2X6MfNhImoIYAkR7WLmld6+mXJhGA0ALVrIHBJCCOEvpkruzHxY+X4CwGwAPXWHHAbQXPO8mbJNf56pzJzMzMkJCd6t2C6EEMI8j8mdiKoTUU31MYDBALbpDpsL4F6l10xvADnMfNTyaIUQQphiplrmIgCzldXJYwB8z8wLiWgsADDzFADzAQwFkA4gH8BI/4QrhBDCDOIgLX1CRNkADvj48gYAXHXLDEXhFK/E6h8Sq/+EU7xWxNqSmT3WawctuVcEEaUyc9hM/hxO8Uqs/iGx+k84xRvIWGXiMCGEiECS3IUQIgKFa3KfGuwAvBRO8Uqs/iGx+k84xRuwWMOyzl0IIYR74VpyF0II4UbYJXciuo6IdivTC48PUgzTiegEEW3TbKtHREuIaK/yva6y3eV0yER0n3L8XiK6z0+xNiei34loBxFtJ6InQjVeIoononVEtFmJ9VVlexIRpSgxzSSiOGV7FeV5urI/UXOu55Xtu4noWqtj1bxPNBFtIqLfwiBWp6m7Q/FzoLxHHSL6iYh2EdFOIuoTirESUTvl96l+5RLRkyERKzOHzReAaAD7ALQCEAdgM4AOQYijP4BuALZptv0XwHjl8XgAryuPhwJYANtCkr0BpCjb6wHYr3yvqzyu64dYGwPopjyuCWAPgA6hGK/ynjWUx7EAUpQYfgQwQtk+BcDDyuNHAExRHo8AMFN53EH5bFQBkKR8ZqL99Fl4GsD3AH5TnodyrJkAGui2hdznQHmfrwA8qDyOA1AnVGPVxBwN4BiAlqEQq19+SD/+8voAWKR5/jyA54MUSyIck/tuAI2Vx40B7FYefwrgTv1xAO4E8Klmu8Nxfoz7VwCDQj1eANUAbIRtXYCTAGL0nwEAiwD0UR7HKMeR/nOhPc7iGJsBWAbgagC/Ke8dkrEq586Ec3IPuc8BgNoAMqC0CYZyrLr4BgP4M1RiDbdqmVCeWvgiLp9P5xhs0zYArmMO+M+iVAV0ha1EHJLxKtUcaQBOAFgCW0n2LDOXGLyvPSZlfw6A+oGKFcC7AJ4FUKY8rx/CsQLGU3eH4ucgCUA2gC+UKq9pZJvXKhRj1RoBYIbyOOixhltyDwtsu/SGVDckIqoB4GcATzJzrnZfKMXLzKXM3AW2UnFPAO2DHJIhIroewAlm3hDsWLzQj5m7wbZy2qNE1F+7M4Q+BzGwVXt+wsxdAeTBVrVhF0KxAgCUtpUbAczS7wtWrOGW3E1NLRwkx0lZfUr5fkLZ7irmgP0sRBQLW2L/jpl/CfV4AYCZzwL4HbaqjTpEpE5yp31fe0zK/toATgUo1ssB3EhEmQB+gK1q5r0QjRWAy6m7Q/FzcAjAIWZOUZ7/BFuyD8VYVUMAbGTm48rzoMcabsl9PYC2So+EONhug+YGOSbVXABqC/d9sNVtq9uNpkNeBGAwEdVVWtIHK9ssRUQE4HMAO5n57VCOl4gSiKiO8rgqbG0DO2FL8re6iFX9GW4FsFwpJc0FMELpoZIE29q+66yMlZmfZ+ZmzJwI2+dwOTPfHYqxAm6n7g65zwEzHwOQRUTtlE0DAewIxVg17kR5lYwaU3Bj9Vfjgh8bLYbC1uNjH4AXgxTDDABHARTDVsoY
BVv96TIAewEsBVBPOZYAfKTEuxVAsuY8D8A2TXI6gJF+irUfbLeEWwCkKV9DQzFeAJcB2KTEug3AS8r2VrAlvHTYbnurKNvjlefpyv5WmnO9qPwMuwEM8fPnYQDKe8uEZKxKXJuVr+3q/04ofg6U9+gCIFX5LMyBrQdJqMZaHba7sNqabUGPVUaoCiFEBAq3ahkhhBAmSHIXQogIJMldCCEikCR3IYSIQJLchRAiAklyF0KICCTJXQghIpAkdyGEiED/D8dZsKfzDlUUAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(loss_change)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x12e7c4780>]"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAHV5JREFUeJzt3XtwnFeZ5/Hv0zdJrVvr5rtk2UkwuUCcRLlCGENgIBkqUDUMJDvsQBYqwLAwMFs1Bbu1TC21VVtUTS2EZYZsFoZhBipLJdxSbIBACCFA4kS2A7k4ju92fJVk3e/d/ewf/Uppy5LVdiR3v92/T1WXut/3qPXE1fnp6LznnNfcHRERKS+RYhcgIiJLT+EuIlKGFO4iImVI4S4iUoYU7iIiZUjhLiJShgoKdzP7rJm9YGbPm9n9ZlY95/yHzazHzJ4NHh9dnnJFRKQQi4a7ma0FPg10ufsVQBS4Y56m33P3zcHjG0tcp4iInINCh2ViQI2ZxYAkcHT5ShIRkdcqtlgDdz9iZv8AHALGgUfc/ZF5mv65mb0FeBn4rLsfPtv7tra2emdn53mULCJSubZt29br7m2LtbPFth8wsybg+8AHgAHgAeBBd/9OXpsWYMTdJ83sY8AH3P1t87zX3cDdAB0dHdccPHjwHP6TRETEzLa5e9di7QoZlnk7sN/de9x9GvgBcFN+A3fvc/fJ4OU3gGvmeyN3v8/du9y9q61t0V88IiJyngoJ90PADWaWNDMDbgF25jcws9V5L2+fe15ERC6sQsbct5rZg8B2IA3sAO4zsy8C3e7+EPBpM7s9OH8K+PDylSwiIotZdMx9uXR1dXl3d3dRfraISFgt5Zi7iIiEjMJdRKQMKdxFRMpQ6ML9peND/MPPd9E/OlXsUkRESlbowv1A7yhfe2wPxwYnil2KiEjJCl24N9YkABgYV89dRGQhIQz3OABD49NFrkREpHSFLtxTyVy4D4wp3EVEFhLecFfPXURkQaEL95p4lHjUGFS4i4gsKHThbmY01iQ0LCMichahC3eAxpoYg5otIyKyoFCGeyqZ0LCMiMhZhDPca+IalhEROYtQhntjTVw9dxGRswhnuCfjDKrnLiKyoHCGe02c4ck06Uy22KWIiJSkUIZ7amYLgol0kSsRESlN4Qz3ZLB52JimQ4qIzCeU4T6zeZguqoqIzC+c4a79ZUREziqc4T7Tc9eMGRGReYUy3FMalhEROatQhvtMz12rVEVE5hfKcI9FI9RVxdRzFxFZQCjDHXK9d91HVURkfqEOd11QFRGZX2jDPZXU5mEiIgsJdbhrnruIyPxCG+6N2tNdRGRBIQ73BEPj07h7sUsRESk5IQ73OFOZLOPTmWKXIiJSckIb7qmkVqmKiCwkvOGuVaoiIgsqKNzN7LNm9oKZPW9m95tZ9ZzzVWb2PTPbY2ZbzaxzOYrNpy0IREQWtmi4m9la4NNAl7tfAUSBO+Y0+wjQ7+4XA18GvrTUhc7VqGEZEZEFFTosEwNqzCwGJIGjc86/B/h28PxB4BYzs6UpcX6v3rBDWxCIiMy1aLi7+xHgH4BDwDFg0N0fmdNsLXA4aJ8GBoGWpS31dDO32lPPXUTkTIUMyzSR65lvANYAtWb2wfP5YWZ2t5l1m1l3T0/P+bzFrNpElFjENOYuIjKPQoZl3g7sd/ced58GfgDcNKfNEaAdIBi6aQT65r6Ru9/n7l3u3tXW1vaaCjezYGdIhbuIyFyFhPsh4AYzSwbj6LcAO+e0eQj4UPD8fcCv/AIsHW3U5mEiIvMqZMx9K7mLpNuB54Lvuc/MvmhmtwfNvgm0mNke4G+Bzy1TvafRtr8iIvOLFdLI3f8e+Ps5h7+Qd34C+IslrKsgqZo4vSOaLSMiMldoV6hCbsaM7sYkInKmUIe7tv0VEZlf6MN9eCJNJqttf0VE8oU+3AGGNGNGROQ0oQ53bfsrIjK/
sgh3LWQSETldqMP91W1/NWNGRCRfyMNdm4eJiMwn1OGuMXcRkfmFOtx1NyYRkfmFOtzj0Qi1iah67iIic4Q63EGrVEVE5hP+cE8m1HMXEZkj9OGeqonrPqoiInOEPtw1LCMicqbQh3tKd2MSETlD6MNd91EVETlT+MM9GWcqnWViOlPsUkRESkbowz0VbEGgcXcRkVeFPtxnV6lqxoyIyKzQh/vs/jLquYuIzAp9uL/ac1e4i4jMKJtw13RIEZFXhT7cNSwjInKm0Id7XVWMaMR0QVVEJE/ow93MSNXE6VfPXURkVujDHaC1rore4clilyEiUjLKItzb6qvoGVG4i4jMKItwX1FfxckhhbuIyIyyCPeZnru7F7sUEZGSUDbhPpXOMjSRLnYpIiIloWzCHaBHF1VFRIByCfc6hbuISL5Fw93MNpnZs3mPITP7zJw2W8xsMK/NF5av5DPN9tw1Y0ZEBIDYYg3cfRewGcDMosAR4IfzNH3C3d+9tOUVRsMyIiKnO9dhmVuAve5+cDmKOV+NNXES0YjCXUQkcK7hfgdw/wLnbjSzP5jZT83s8tdY1zkxM9rqqzg5PHEhf6yISMkqONzNLAHcDjwwz+ntwHp3vxL4X8CPFniPu82s28y6e3p6zqfeBbXWV6nnLiISOJee+63Adnc/MfeEuw+5+0jw/GEgbmat87S7z9273L2rra3tvIueT1udwl1EZMa5hPudLDAkY2arzMyC59cF79v32ssrXFt9Fb2aLSMiAhQwWwbAzGqBdwAfyzv2cQB3vxd4H/AJM0sD48AdfoH3Amirr6JvdIp0JkssWhbT90VEzltB4e7uo0DLnGP35j3/GvC1pS3t3LTVV+EOp0anWNFQXcxSRESKrmy6uDOrVE9q3F1EpHzCfUWDVqmKiMwom3Cf3V9G+7qLiJRRuGt/GRGRWWUT7tXxKPXVMc11FxGhjMIdgjsyKdxFRMos3LVKVUQEKLdwD+6lKiJS6coq3FfUV6vnLiJCmYV7W30VI5NpxqZ0o2wRqWxlF+4AvcNTRa5ERKS4yjLcddMOEal05RXudbqXqogIlFu4a5WqiAhQZuHeXJsgYuq5i4iUVbhHI0aLFjKJiJRXuAOs0BYEIiLlF+5apSoiUo7hrmEZEZEyDPdgWCabvaD35xYRKSllGe7prDMwPl3sUkREiqYswx00HVJEKlv5hbtWqYqIlF+4r2ioBqBnRPvLiEjlKrtw17CMiEgZhnttIkpNPKpwF5GKVnbhbma6UbaIVLyyC3fIDc2cVLiLSAUrz3DXKlURqXDlGe7aX0ZEKlxZhvuqxmoGxqYZndSNskWkMpVluHe21AJwoG+0yJWIiBRHeYZ7axKAA71jRa5ERKQ4yjPc1XMXkQq3aLib2SYzezbvMWRmn5nTxszsq2a2x8z+aGZXL1/Ji6utirGyoYr9vQp3EalMscUauPsuYDOAmUWBI8AP5zS7FbgkeFwPfD34WjSdLbUKdxGpWOc6LHMLsNfdD845/h7gXz3nKSBlZquXpMLztKG1lgMKdxGpUOca7ncA989zfC1wOO/1K8GxotnQWkvf6BRDE7pph4hUnoLD3cwSwO3AA+f7w8zsbjPrNrPunp6e832bgnS2BhdV1XsXkQp0Lj33W4Ht7n5innNHgPa81+uCY6dx9/vcvcvdu9ra2s6t0nO0IQh3jbuLSCU6l3C/k/mHZAAeAv4qmDVzAzDo7sdec3WvQUdzEjOFu4hUpkVnywCYWS3wDuBjecc+DuDu9wIPA7cBe4Ax4K4lr/QcVcejrGms0bCMiFSkgsLd3UeBljnH7s177sAnl7a0125Day37+7RKVUQqT1muUJ3R2Zpkf88Iud89IiKVo7zDvaWWoYk0/WOaDikilaWsw10zZkSkUlVEuOuiqohUmrIO9/bmJNGIaXdIEak4ZR3u8WiEdU01GpYRkYpT1uEO2h1SRCpT2Yf7zO6Qmg4pIpWkIsJ9dCpDz8hksUsR
Eblgyj7cX90dUitVRaRylH24b2iZmes+UuRKREQunLIP97VNNcSjxn713EWkgpR9uEcjRkdzUguZRKSilH24QzBjRguZRKSCVES4d7bkwj2b1XRIEakMFRHuG9pqmZjOcnxootiliIhcEJUR7i3aQExEKktFhPtFK+oA+MMrg0WuRETkwqiIcF/ZUE3X+ia+98whjbuLSEWoiHAH+Pc3rudA3xi/29tb7FJERJZdxYT7u65YRUttgn978mCxSxERWXYVE+5VsSh/0dXOL3ee4NjgeLHLERFZVhUT7gB/eX0HDty/9VCxSxERWVYVFe7tzUm2vK6N+585zHQmW+xyRESWTUWFO8AHb1hPz/Akj7xwotiliIgsm4oL9y2bVrA2VcN3ntKFVREpXxUX7tGI8e+u7+DJfX3sOTlc7HJERJZFxYU7wAeubSceNe59fB9pjb2LSBmqyHBvravi/V3tPLjtFd75ld/w8HPHtHJVRMpKRYY7wH9/7xXc+8GrMTP++rvbuf0ff8uvd53EXSEvIuFnxQqzrq4u7+7uLsrPzpfJOj/acYQv//JlXukf543rGvnrLRfxp5etIhKxYpcnInIaM9vm7l2Ltqv0cJ8xlc7y/e2v8L8f38uBvjE2ttXy8T+5iPduXksiVrF/4IhIiVG4n6dM1nn4uWN8/dd7efHYEGsaq/nozRu547p2kolYscsTkQq3pOFuZingG8AVgAP/wd2fzDu/BfgxsD849AN3/+LZ3rNUw32Gu/P4yz3806/38vT+UzQl43z4pg186Kb1pJKJYpcnIhWq0HAvtCt6D/Azd3+fmSWA5DxtnnD3d59LkaXMzNiyaQVbNq1g28FT/NNje/nyL1/m/zyxj0+97WLuetMGDdeISMlaNJ3MrBF4C/BNAHefcveB5S6slFyzvplvfvhafvaZm7luQzP/46cv8c6v/IbHXjpZ7NJEROZVSNdzA9ADfMvMdpjZN8ysdp52N5rZH8zsp2Z2+dKWWRpev6qBf/7wtXzrrmsx4K5/eYa7vvW0VrqKSMlZdMzdzLqAp4A3uftWM7sHGHL3/5rXpgHIuvuImd0G3OPul8zzXncDdwN0dHRcc/BgePd3mUpn+fbvD3DPo7sZm0rz3qvW8plbXkdHy3wjViIiS2PJLqia2SrgKXfvDF7fDHzO3f/sLN9zAOhy9wXvaVfqF1QLdWp0insf38u3f3+ATNZ5/7XtfOptF7O6sabYpYlIGSo03BcdlnH348BhM9sUHLoFeHHOD1tlZhY8vy54375zrjqEmmsT/OfbLuU3f/dW7ryugwe6D/PmLz3GJ76zjd/u7tW2BiJSFIVOhdxMbipkAtgH3AV8AMDd7zWz/wh8AkgD48Dfuvvvz/ae5dJzn+vwqTH+7amDPNB9mP6xaTpbktx5XQfvvnINa1PqzYvIa6NFTEU2MZ3hZ88f57tbD/LMgX4ALl/TwDsuW8k7LlvJZasbCP7YEREpmMK9hOzrGeEXL57gFy+eYNuhftyhvbmGP3vDGt79xtVcvkZBLyKFUbiXqJ7hSR7deYKHnz/O7/b0ksk6nS1JbnvDarZsWsFVHSniUS2OEpH5KdxD4NToFI+8cJz/99wxfr+3j0zWqU1EufGiFm6+pI2bLmrh4hV16tWLyCyFe8gMTUzz5N4+ntjdw29e7uXQqTEAWmoTXLehmes2NHPDxhY2razXVsQiFWyp95aRZdZQHeedl6/inZevAuBg3yhb953iqf19bN13ip8+fxyAVDLO9RuauXFjCzdclAt79exFZC6Fe4la31LL+pZa3n9tO5CbYvn0/lM8ua+Pp/b18fMXTgDQWpfgTRe3cvMlbbz54lZWNVYXs2wRKREK95Bob07S3pzkz69ZB+TC/sl9ffxuTy+/29PLj589CsCG1lo2t6e4qiPF5vYUr1/VoN0rRSqQwj2kZsL+/V3tZLPOS8eH+e2eHroP9PPbPb38cMcRAKrjEW7Y2MKW17WxZdMK
Olvn2/NNRMqNLqiWIXfn6OAEzx4a4JkDp3j85R72944C0NmS5O2XruTWN6zmqvaULs6KhIxmy8hpDvaN8utdPTy26yS/39PHVCbLqoZq3nXFKm69YhVdnc1EFfQiJU/hLgsampjmVztP8vBzx3j85R4m01la6xK88/JV3HrFam7Y2ExMC6lESpLCXQoyOpnmsV0n+enzx/nVzpOMT2doSsa5+ZI2br4kNwtHM3BESofCXc7Z+FSGx1/u4ZEXjvOb3b30jkwCcPGKuiDoW7l+Qwu1VboOL1IsCnd5TdxzM3Ce2N3DE7t7eXr/KSbTWeJR46qOJm6+uJVrOpt447oUdQp7kQtG4S5LamI6Q/eBfp7Y08Nvd/fywtEhACIGr1tZz1UdTVzVkeLqjhQbW+s0C0dkmSjcZVkNjE3x7OEBdhwaYMfhAZ491M/QRBqAhuoYV3U0sbk9xaWrG7h0dT3tTUkFvsgS0N4ysqxSyQRbNq1gy6YVAGSzzr7eEbYfGmDHoX62Hxzgq7t3M9N3SCaibFpVz5XrUnR1NtG1vlkXakWWkXrusmzGptLsPjHCS8eH2HlsmBePDfHHVwaYmM4CsDZVwxvXNdLRkqSj+dXH2lSNpmKKLEA9dym6ZCLGle0prmxPzR6bzmR58egQ3Qf72XbwFC8dG+bRnSeZymRn28SjRkdzkg2ttWxorWVVYw2NNXEaa+KkknHa6qpY35LUbpgiZ6GeuxRdNuucGJ7gUN8YB/vG2N83yv6eUfb3jnKgb5TJdPaM70kl41zd0cQ165tmN0lLJtRXkfKnnruERiRirG6sYXVjDddvbDntXDbrDE+kGRyfZmB8isHxaY4OjLPj0ADbDvbzq5dOAhCNGJeurp8N/CvWNtLelNSOmFKx1HOXUBscm2b74X52HOxn26F+dhwaYGwqA+Smaa5urGF9S5L1LUnWNSVZk6pmTWMNa1I1rG6s1ti+hI567lIRGpNx3rppBW8NZu2kM1l2nRjmxaNDHD41xsFTuaGeR144Qd/o1GnfWx2P8PpVDVyxtoEr1jRy6eoG1jXV0Fyb0Hi+hJ7CXcpKLBrh8jWNXL6m8Yxz41MZjg2Oc3RggiMDY+w+McLzRwf58Y6jfOepQ7PtqmIRVjdWszq4kFuTiFIdj1AVi5JKxnndyno2raqns6VWO2lKyVK4S8WoSUTZ2FbHxra6045ns87h/jF2HhsOwn+co4MTHB+cYF/vCBPTWcanM0xMZxiZTM/O3a+KRbh4RR3rmmpY2VDNyoZqVtRXsTZVw/rWWlY3VGvhlhSNwl0qXiRis/esXczEdIbdJ0bYdWKYXceHePnECAd6x9i6/xQDY9OntU3EIqxvzo33t9VX01aXoKWuita6KpqScRqT8dkpnnVVMQ0FyZJSuIucg+p4lDesa+QN684c9pmYznBiaIIj/ePs7xvNTevsHeXwqTGePTxA3+gUC81fqIpFWN+SpLMlN7e/s7U2uBBcy6qGag3/yDlTuIsskep4dPYvgJsubj3jfCbr9I9N0TsySf/oNIPj0wwG0zt7hifZ3zvGvt7cHbPyF3UlopHZoZ/aqig1iRjJeJS66hjrW5JsbK1jY1vul4CGgWSGwl3kAolGjNZgWOZsMlnn6MA4h06NcSiY7XPo1Cg9w5McG5xmbCrD2FSaofE049OZ2e+riUdZ2VBFXXWM+qo4ddUxGqrjtNYlaKlL0FpXRUtdFfXVMeqqYtRWxahNRKmrimlKaBlSuIuUmGjEaG9O0t6c5E1naefunByeZG/PCPt6RtnXM0rvyCQjk2mGJ6Y5fGqMwfFp+kanmJpnlW++6niE+uo49VUx6qtjNCYTNCXjpGripJKJ3HWB6lhwPk6yKko8EiEaMWJRIxox6qtipJIJLRwrEQp3kZAys9lZOjdddOYw0Ax3Z2QyTd/IFH2jkwxNpBmdTDM2mZv9M/MYnphmeCLN0ESawbEpDvaN0j86NbuVc6HqqmI01capr4oT
jxqxaIRYxIhHIyQTueGkuqrcI5mIkojlppnmvuaeV8UiVMejVMUjxKMR4tHc98+8TyRiRM2IRCBqRiwSIRbN/aKJRyIankLhLlL2zCzXK6+O09m6+IygudKZbBD+6dN+GWQyTjrrZLJOOptlaCLNwOgUp8amGBibZnhimulM7tx0xhmdStM7MsnwRJrRqTQjE2nS2Qu3Qj4Wyf2FEYvkfuGYgXvul99MFdGIEbHcIxqBiBkGp81kikTAMMw47dxsi7zfKzNP586EuuPadj5688Zl+K98lcJdRM4qFo2QSiZIJRNL+r7uznTGmcpkmUpnmUxnmJzOMhk8n5jOMjGdIZ3NMpXO/ZJIZ5zpTJasO5ksZNzJZnO/ZNKZLOlscD7rkB+o7mR8pl2u7Yz84PWgXSabW/+QzQt+d3AcnNnjM7OfPO/7Z9/rjCevWuy6y1JQuItIUZgZiZjlxuiXP+sqTkFXPswsZWYPmtlLZrbTzG6cc97M7KtmtsfM/mhmVy9PuSIiUohCe+73AD9z9/eZWQJIzjl/K3BJ8Lge+HrwVUREimDRnruZNQJvAb4J4O5T7j4wp9l7gH/1nKeAlJmtXvJqRUSkIIUMy2wAeoBvmdkOM/uGmc295L4WOJz3+pXgmIiIFEEh4R4Drga+7u5XAaPA587nh5nZ3WbWbWbdPT095/MWIiJSgELC/RXgFXffGrx+kFzY5zsCtOe9XhccO4273+fuXe7e1dbWdj71iohIARYNd3c/Dhw2s03BoVuAF+c0ewj4q2DWzA3AoLsfW9pSRUSkUIXOlvkU8N1gpsw+4C4z+ziAu98LPAzcBuwBxoC7lqFWEREpUNFukG1mPcDB8/z2VqB3Ccu5EMJWs+pdXqp3eZVzvevdfdFx7aKF+2thZt2F3P27lIStZtW7vFTv8lK9Ba5QFRGRcFG4i4iUobCG+33FLuA8hK1m1bu8VO/yqvh6QznmLiIiZxfWnruIiJxF6MLdzN5lZruC7YXPaxuE5WRm/2xmJ83s+bxjzWb2CzPbHXxtKmaN+cys3cweM7MXzewFM/ub4HhJ1mxm1Wb2tJn9Iaj3vwXHN5jZ1uBz8b1gTUbJMLNosDfTT4LXJVuvmR0ws+fM7Fkz6w6OleTnYcZ825KXas1mtin4t515DJnZZ5a63lCFu5lFgX8kt8XwZcCdZnZZcas6w78A75pz7HPAo+5+CfAo57k3zzJJA//J3S8DbgA+GfyblmrNk8Db3P1KYDPwrmBV9JeAL7v7xUA/8JEi1jifvwF25r0u9Xrf6u6b86bnlernYcbMtuSvB64k929dkjW7+67g33YzcA25hZ8/ZKnrdffQPIAbgZ/nvf488Pli1zVPnZ3A83mvdwGrg+ergV3FrvEstf8YeEcYaiZ3X4Ht5O4d0AvE5vucFPtBbq+lR4G3AT8hd2vNUq73ANA651jJfh6ARmA/wTXEMNScV+OfAr9bjnpD1XMnvFsLr/RX99o5DqwsZjELMbNO4CpgKyVcczDE8SxwEvgFsBcYcPd00KTUPhdfAf4OmLlxZwulXa8Dj5jZNjO7OzhWsp8HFt6WvJRrnnEHcH/wfEnrDVu4h57nfi2X3BQlM6sDvg98xt2H8s+VWs3unvHcn7TrgOuA1xe5pAWZ2buBk+6+rdi1nIM3u/vV5IY/P2lmb8k/WWqfBwrYlrwEaya4znI78MDcc0tRb9jCvaCthUvQiZk7UwVfTxa5ntOYWZxcsH/X3X8QHC7pmgE8d0ewx8gNa6TMbGYjvFL6XLwJuN3MDgD/l9zQzD2Ubr24+5Hg60lyY8HXUdqfh4W2JS/lmiH3y3O7u58IXi9pvWEL92eAS4KZBglyf9I8VOSaCvEQ8KHg+YfIjWuXBDMzcrdQ3Onu/zPvVEnWbGZtZpYKnteQuz6wk1zIvy9oVjL1uvvn3X2du3eS+7z+yt3/khKt18xqzax+5jm5MeHnKdHPA5x1W/KSrTlwJ68OycBS11vsCwrn
cQHiNuBlcuOs/6XY9cxT3/3AMWCaXI/iI+TGWB8FdgO/BJqLXWdevW8m9+ffH4Fng8dtpVoz8EZgR1Dv88AXguMbgafJbTv9AFBV7FrnqX0L8JNSrjeo6w/B44WZ/8dK9fOQV/dmoDv4XPwIaCrlmoFaoA9ozDu2pPVqhaqISBkK27CMiIgUQOEuIlKGFO4iImVI4S4iUoYU7iIiZUjhLiJShhTuIiJlSOEuIlKG/j9HwIOLLxWnSQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(val_loss_change)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Overall, the validation loss decreases over the course of training, which suggests the model is learning as intended."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
